xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2from __future__ import annotations
3
4import abc
5
6import argparse
7import collections
8import contextlib
9import copy
10import csv
11import dataclasses
12import functools
13import importlib
14import itertools
15import logging
16import os
17import pathlib
18import shutil
19import signal
20import subprocess
21import sys
22import time
23import weakref
24from contextlib import contextmanager
25
26from typing import (
27    Any,
28    Callable,
29    Generator,
30    List,
31    Mapping,
32    NamedTuple,
33    Optional,
34    Sequence,
35    Tuple,
36    Type,
37    TYPE_CHECKING,
38)
39from typing_extensions import Self
40from unittest.mock import MagicMock
41
42import numpy as np
43import pandas as pd
44import psutil
45from scipy.stats import gmean, ttest_ind
46from tqdm.auto import tqdm, trange
47
48import torch
49import torch._dynamo
50import torch._dynamo.utils
51import torch._export
52import torch.distributed
53import torch.multiprocessing as mp
54from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU
55from torch._dynamo.profiler import fx_insert_profiling, Profiler
56from torch._dynamo.testing import (
57    dummy_fx_compile,
58    format_speedup,
59    reset_rng_state,
60    same,
61)
62
63try:
64    from torch._dynamo.utils import (
65        clone_inputs,
66        graph_break_reasons,
67        maybe_enable_compiled_autograd,
68    )
69    from torch._inductor.utils import fresh_inductor_cache
70except ImportError:
71    from _dynamo.utils import (
72        clone_inputs,
73        graph_break_reasons,
74        maybe_enable_compiled_autograd,
75    )
76
77import torch._functorch.config
78from torch._functorch.aot_autograd import set_model_name
79from torch._inductor import config as inductor_config, metrics
80from torch._subclasses.fake_tensor import FakeTensorMode
81from torch.utils import _pytree as pytree
82from torch.utils._pytree import tree_map, tree_map_only
83
84try:
85    import torch_xla
86    import torch_xla.core.xla_model as xm
87
88    # This is to work around the backward issue https://github.com/pytorch/xla/issues/4174
89    torch_xla._XLAC._init_computation_client()
90except ImportError:
91    # ignore the error if torch_xla is not installed
92    pass
93
94
95if TYPE_CHECKING:
96    from torch.onnx._internal.fx import diagnostics
97
98
99log = logging.getLogger(__name__)
100
101# We are primarily interested in TF32
102torch.backends.cuda.matmul.allow_tf32 = True
103
104# Suppress torch.profiler spam
105os.environ["KINETO_LOG_LEVEL"] = "5"
106
107current_name = ""
108current_device = ""
109current_onnx_compiler = ""
110current_batch_size = None
111output_filename = None
112disable_output = False
113
114MAX_DOWNLOAD_ATTEMPTS = 5
115
116
117class CI(NamedTuple):
118    backend: str  # aot_eager or inductor
119    training: bool
120    dynamic: bool = False
121    device: str = "cuda"
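    # Illustrative construction (values depend on the CI job being configured):
    #   CI("inductor", training=True, dynamic=False, device="cuda")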
122
123
124CI_SKIP_OPTIMIZER = {
125    # TIMM
126    "convmixer_768_32",  # accuracy
127    "hrnet_w18",  # Stack issue in fx
128    # HF
129    "pnasnet5large",  # Stack issue in fx
130    "MobileBertForMaskedLM",  # Stack issue in fx
131    "MobileBertForQuestionAnswering",  # Stack issue in fx
132    "PegasusForConditionalGeneration",  # OOM
133}
134
135CI_SKIP_DYNAMIC_BATCH_ONLY = {
136    "sam",
137    # See https://github.com/mindee/doctr/blob/f2114758d529ed8d3d0030581638f0520b6b98d8/doctr/models/detection/core.py#L89
138    # It iterates over the batch, which is dynamic, and dynamo chokes
139    # We should be able to graphbreak there.
140    "doctr_det_predictor",
141    "dlrm",
142    "pyhpc_isoneutral_mixing",
143    "pyhpc_equation_of_state",
144    "pyhpc_turbulent_kinetic_energy",
145    "detectron2_fcos_r_50_fpn",
146    "hf_T5_generate",
147}
148
149# These models currently fail accuracy with eager Adam optimizer
150# so we use SGD when running the full benchmarks
151# https://github.com/pytorch/pytorch/issues/115966
152BENCHMARK_USE_SGD = {
153    # TorchBench
154    "BERT_pytorch",
155    "LearningToPaint",
156    "alexnet",
157    "dcgan",
158    "demucs",
159    "densenet121",
160    "dlrm",
161    "fastNLP_Bert",
162    "mobilenet_v2",
163    "phlippe_densenet",
164    "phlippe_resnet",
165    "pytorch_stargan",
166    "resnet18",
167    "shufflenet_v2_x1_0",
168    "speech_transformer",
169    "squeezenet1_1",
170    "stable_diffusion_text_encoder",
171    "timm_efficientdet",
172    "timm_nfnet",
173    "timm_regnet",
174    "timm_vision_transformer",
175    "timm_vovnet",
176    "vgg16",
177    "hf_T5",  # Fails dynamic https://github.com/pytorch/pytorch/issues/115968
178    # HF
179    "AlbertForMaskedLM",
180    "BartForCausalLM",
181    "BartForConditionalGeneration",
182    "BlenderbotSmallForCausalLM",
183    "BlenderbotSmallForConditionalGeneration",
184    "DebertaV2ForQuestionAnswering",  # eager OOM
185    "ElectraForCausalLM",
186    "M2M100ForConditionalGeneration",
187    "MBartForCausalLM",
188    "MBartForConditionalGeneration",
189    "OPTForCausalLM",
190    "PLBartForCausalLM",
191    "PLBartForConditionalGeneration",
192    "PegasusForCausalLM",
193    "Speech2Text2ForCausalLM",
194    "TrOCRForCausalLM",
195    "XGLMForCausalLM",
196    # TIMM
197    "adv_inception_v3",
198    "botnet26t_256",
199    "cait_m36_384",  # OOM
200    "coat_lite_mini",
201    "convit_base",
202    "dpn107",
203    "fbnetv3_b",
204    "gernet_l",
205    "lcnet_050",
206    "mixnet_l",
207    "res2net101_26w_4s",
208    "res2net50_14w_8s",
209    "res2next50",
210    "resnest101e",
211    "sebotnet33ts_256",
212    "swsl_resnext101_32x16d",
213    "tf_efficientnet_b0",
214    "ghostnet_100",
215    "gmixer_24_224",
216    "tinynet_a",
217}
218
219# These models OOM in CI
220# due to the extra memory of Adam optimizer states,
221# so we fall back to SGD in CI
222CI_USE_SGD = {
223    "torchrec_dlrm",
224    "demucs",
225    "detectron2_fasterrcnn_r_101_c4",
226    "detectron2_fasterrcnn_r_101_dc5",
227    "detectron2_fasterrcnn_r_101_fpn",
228    "detectron2_fasterrcnn_r_50_c4",
229    "detectron2_fasterrcnn_r_50_dc5",
230    "detectron2_fasterrcnn_r_50_fpn",
231    "detectron2_maskrcnn_r_101_c4",
232    "detectron2_maskrcnn_r_101_fpn",
233    "detectron2_maskrcnn_r_50_c4",
234    "detectron2_maskrcnn_r_50_fpn",
235    "hf_T5_base",
236    "hf_clip",
237    "llama_v2_7b_16h",
238    "mobilenet_v2_quantized_qat",
239    "phi_1_5 resnet50_quantized_qat",
240    "BlenderbotForCausalLM",
241    "cait_m36_384",
242    "DALLE2_pytorch",
243    "moco",
244    "timm_efficientdet",
245    "ghostnet_100",
246    "regnety_002",
247    "poolformer_m36",
248    "inception_v3",
249    "tinynet_a",
250    "selecsls42b",
251    "mobilevit_s",
252    "pytorch_CycleGAN_and_pix2pix",
253    "vision_maskrcnn",
254    "resmlp_12_224",
255    "dlrm",
256    "resnet50",
257    "dm_nfnet_f0",
258    "pit_b_224",
259    "tf_mixnet_l",
260}
261
262
263DO_NOT_CAST_INPUTS = {"stable_diffusion"}
264
265
266# Maps a benchmark model name to a list of status codes. For any listed entry, we'll
267# capture TORCH_COMPILE_DEBUG logs in CI runs and preserve them (i.e., for upload) if
268# the result status matches one listed.
269CI_PRESERVE_COMPILE_DEBUG = {
270    # For example:
271    # "mnasnet1_0": ["fail_accuracy"],
272}
273
274
275def model_specified_by_path(path_and_class_str):
276    return ":" in path_and_class_str
277
278
279def load_model_from_path(path_and_class_str):
280    configs = {}
281    for kvstr in path_and_class_str.split(","):
282        k, v = kvstr.split(":")
283        configs[k] = v
284
285    for name in ["path", "class"]:
286        if name not in configs:
287            raise RuntimeError(
288                "Invalid --only arguments. Check help message for the correct format"
289            )
290
291    path = configs["path"]
292    class_name = configs["class"]
293
294    if path[:1] != "/":
295        raise RuntimeError(
296            "Use absolute path since dynamo may change the current working directory which makes using relative path tricky"
297        )
298
299    spec = importlib.util.spec_from_file_location("module_name", path)
300    module = importlib.util.module_from_spec(spec)
301    spec.loader.exec_module(module)
302
303    model_class = getattr(module, class_name)
304    assert issubclass(model_class, torch.nn.Module)
305    model = model_class()
306    assert hasattr(model, "get_example_inputs")
307    inputs = model.get_example_inputs()
308    return model, inputs
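# Sketch of the `--only` value that load_model_from_path expects; the file and class
# names here are hypothetical:
#   --only "path:/abs/path/to/my_model.py,class:MyModel"
# The file must define MyModel as a torch.nn.Module subclass exposing a
# get_example_inputs() method, and the path must be absolute.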
309
310
311def output_csv(filename, headers, row):
312    global disable_output
313    if disable_output:
314        return
315    if os.path.exists(filename):
316        with open(filename) as fd:
317            lines = list(csv.reader(fd)) or [[]]
318            if headers and len(headers) > len(lines[0]):
319                # if prior results failed the header might not be filled in yet
320                lines[0] = headers
321            else:
322                headers = lines[0]
323    else:
324        lines = [headers]
325    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
326    with open(filename, "w") as fd:
327        writer = csv.writer(fd, lineterminator="\n")
328        for line in lines:
329            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))
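# Sketch of a typical call (headers and values are illustrative):
#   output_csv(output_filename, ("dev", "name", "batch_size", "speedup"),
#              [current_device, current_name, current_batch_size, 1.25])
# Existing rows in the file are kept, floats are written with six decimal places,
# and rows shorter than the header are right-padded with "0".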
330
331
332def nothing(f):
333    return f
334
335
336@functools.lru_cache(None)
337def patch_torch_manual_seed():
338    """Make torch manual seed deterministic. Helps with accuracy testing."""
339
340    def deterministic_torch_manual_seed(*args, **kwargs):
341        from torch._C import default_generator
342
343        seed = 1337
344        if HAS_CUDA:
345            import torch.cuda
346
347            if not torch.cuda._is_in_bad_fork():
348                torch.cuda.manual_seed_all(seed)
349        if HAS_XPU:
350            import torch.xpu
351
352            if not torch.xpu._is_in_bad_fork():
353                torch.xpu.manual_seed_all(seed)
354        return default_generator.manual_seed(seed)
355
356    torch.manual_seed = deterministic_torch_manual_seed
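# Note: patch_torch_manual_seed() globally monkey-patches torch.manual_seed so that
# every call reseeds the default generator (and CUDA/XPU, when available) with the
# fixed seed 1337; the lru_cache makes repeated calls a no-op.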
357
358
359def empty_gpu_cache(device):
360    """
361    Explicitly empty the GPU cache to avoid OOM in subsequent runs.
362    """
363
364    if device not in ["cuda", "xpu"]:
365        log.warning(
366            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
367            device,
368        )
369        return
370
371    if device == "cuda":
372        torch.cuda.empty_cache()
373    elif device == "xpu":
374        torch.xpu.empty_cache()
375
376
377def synchronize():
378    pass
379
380
381def summarize_graph_break(filename):
382    """
383    Sorts and de-dupes the graph breaks on the reason string. Note that this
384    function is just a best effort to reduce the logging information. We could
385    miss some graph breaks because of the de-duping. We can further refine this
386    function as the need arises.
387    """
388    log_file = f"{filename.rstrip('.csv')}_graph_breaks.csv"
389    if os.path.exists(log_file):
390        df = pd.read_csv(log_file)
391        df = df.sort_values("reason").drop_duplicates(subset="reason")
392
393        # Specialize for multi tensor sgd as reason is not identical
394        multi_tensor_sgd_row = df.loc[df["reason"].str.contains("_multi_tensor_sgd")]
395        if len(multi_tensor_sgd_row):
396            df = df[
397                ~df["reason"].str.contains("_multi_tensor_sgd")
398            ]  # Drop all sgd rows
399            df = pd.concat(
400                [df, pd.DataFrame([multi_tensor_sgd_row.iloc[0]])], axis=0
401            )  # Add back a single row
402        df.to_csv(f"{log_file.rstrip('.csv')}_deduped.csv", index=False)
403
404
405def print_summary(filename, print_dataframe=False):
406    if not (filename and os.path.exists(filename)):
407        return
408    data = pd.read_csv(filename)
409    if "tag" in data.columns:
410        for tag in data.tag.unique():
411            if tag == "0.0000":
412                continue  # This happens for failed runs
413            print(f"\nSummary for tag={tag}:")
414            print_summary_table(data[data.tag == tag], print_dataframe=print_dataframe)
415    else:
416        print_summary_table(data, print_dataframe=print_dataframe)
417    summarize_graph_break(filename)
418
419
420def print_summary_table(data, print_dataframe=False):
421    if print_dataframe:
422        pd.options.display.max_rows = 1000
423        pd.options.display.max_columns = 1000
424        pd.options.display.width = 2000
425        print(data)
426    width = max(map(len, data.columns))
427    for col in data.columns:
428        try:
429            if col in ("dev", "name", "batch_size", "tag"):
430                continue
431            elif col in ("pct_ops", "pct_time"):
432                print(col.ljust(width), f"{data[col].mean():.3%}")
433            elif col in ("graphs", "graph_calls", "captured_ops", "total_ops"):
434                print(col.ljust(width), f"{data[col].mean():.3f}")
435            elif col in ("compilation_latency"):
436                print(col.ljust(width), f"mean={data[col].mean():.3f} seconds")
437            elif col in ("compression_ratio"):
438                print(col.ljust(width), f"mean={data[col].mean():.3f}x")
439            elif col in ("accuracy"):
440                pass_rate = (data[col] == "pass").mean()
441                print(col.ljust(width), f"pass_rate={100*pass_rate:.2f}%")
442            else:
443                cdata = data[col]
444                print(
445                    col.ljust(width),
446                    f"gmean={gmean(cdata):.2f}x mean={cdata.mean():.3f}x",
447                )
448        except Exception as e:
449            pass
450
451
452def tensor_is_on_xla(tensors):
453    def visit(x: torch.Tensor):
454        nonlocal result
455        if x.device.type == "xla":
456            result = True
457
458    result = False
459    tree_map_only(torch.Tensor, visit, tensors)
460    return result
461
462
463def timed(
464    model,
465    model_iter_fn,
466    example_inputs,
467    times=1,
468    return_result=False,
469    collect_outputs=False,
470):
471    use_xla = tensor_is_on_xla(example_inputs)
472    synchronize()
473
474    if use_xla:
475        xm.mark_step()
476        xm.wait_device_ops()
477
478    time_total = 0
479    # Don't collect outputs, to measure timing correctly
480    for _ in range(times):
481        # Put this call inside the loop to reset the seed for each iteration.
482        # Don't include reset_rng_state() to correctly measure timing
483        reset_rng_state(use_xla)
484        t_iter_begin = time.perf_counter()
485        result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs)
486
487        # instead of calling sync on result_list, we should call mark_step.
488        # In training case, result_list may be empty, but we want to
489        # send all the pending graphs for compilation.
490        if use_xla:
491            # For the model running on regular torchxla (baseline), we need the
492            # mark step to send the accumulated graph for compilation.
493            #
494            # For the model running with dynamo/torchxla bridge, in training case,
495            # we need the mark step to send the optimizer graph out for
496            # compilation.
497            xm.mark_step()
498        t_iter_end = time.perf_counter()
499        time_total += t_iter_end - t_iter_begin
500
501    t_0 = time.perf_counter()
502    if use_xla:
503        xm.wait_device_ops()
504    synchronize()
505    t_1 = time.perf_counter()
506    time_total += t_1 - t_0
507    return (time_total, result) if return_result else time_total
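# Minimal usage sketch (assuming `model`, `iter_fn` and `inputs` were prepared by the
# harness); not executed here:
#   latency = timed(model, iter_fn, inputs, times=10)
#   latency, out = timed(model, iter_fn, inputs, times=10, return_result=True)
# The return value is the wall-clock time summed over `times` iterations plus the
# final device synchronization.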
508
509
510def _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]:
511    # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
512    # and consumed like `model(**example_inputs)`.
513    # For other benchmarks, example_inputs are formatted as tuple and consumed
514    # like `model(*example_inputs)`.
515    if isinstance(example_inputs, dict):
516        return (), example_inputs
517    else:
518        return tuple(example_inputs), {}
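# Illustrative normalization (t1/t2 stand for example tensors):
#   {"input_ids": t1, "attention_mask": t2} -> ((), {"input_ids": t1, "attention_mask": t2})
#   [t1, t2]                                -> ((t1, t2), {})
# so callers can always invoke model(*args, **kwargs).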
519
520
521def _register_dataclass_output_as_pytree(example_outputs) -> None:
522    # NOTE(angelayi): For huggingface benchmark, some example outputs are
523    # formatted as a dataclass which pytree cannot consume. So we want
524    # to register the pytree implementation here
525    example_outputs_flat = pytree.tree_leaves(example_outputs)
526    output_dataclass_types = [
527        type(out) for out in example_outputs_flat if dataclasses.is_dataclass(type(out))
528    ]
529    for output_type in output_dataclass_types:
530        from torch._export.utils import register_dataclass_as_pytree_node
531
532        register_dataclass_as_pytree_node(
533            output_type,
534            serialized_type_name=f"{output_type.__module__}.{output_type.__name__}",
535        )
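# Illustrative: for a HuggingFace-style output dataclass such as the hypothetical
#   @dataclasses.dataclass
#   class HypotheticalOutput:
#       logits: torch.Tensor
# this registers the type with pytree so that export/tracing code can flatten and
# unflatten instances like any other container.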
536
537
538class Stats:
539    totals = collections.defaultdict(collections.Counter)
540
541    @classmethod
542    def reset_counters(cls):
543        for k, v in torch._dynamo.utils.counters.items():
544            cls.totals[k].update(v)
545        ok = torch._dynamo.utils.counters["frames"]["ok"]
546        total = torch._dynamo.utils.counters["frames"]["total"]
547        torch._dynamo.utils.counters.clear()
548        return ok, total
549
550    @classmethod
551    def print_summary(cls):
552        for k, v in sorted(cls.totals.items()):
553            lines = "\n  ".join(map(str, v.most_common(50)))
554            print(f"STATS {k}\n  {lines}")
555
556    @classmethod
557    def aot_summary(cls):
558        return [cls.totals["aot_autograd"]["total"], cls.totals["aot_autograd"]["ok"]]
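# Sketch of how Stats is meant to be used around a benchmark run (not executed here):
#   ok, total = Stats.reset_counters()   # snapshot and clear dynamo frame counters
#   ...
#   Stats.print_summary()                # dump the accumulated counters at the end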
559
560
561def coverage_experiment(args, model_iter_fn, model, example_inputs):
562    """
563    Test operator/model coverage of TorchDynamo and record statistics
564    taken from a profiler.  This target is mainly intended to check
565    correctness.
566
567    Writes to ./coverage.csv
568    """
569    profiler = Profiler()
570    frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
571    with profiler.prof:
572        frozen_model_iter_fn(model, example_inputs)
573    coverage_result = profiler.results()
574    output_csv(
575        output_filename,
576        (
577            "dev",
578            "name",
579            "batch_size",
580            "graphs",
581            "graph_calls",
582            "captured_ops",
583            "total_ops",
584            "pct_ops",
585            "pct_time",
586        ),
587        [
588            current_device,
589            current_name,
590            current_batch_size,
591        ]
592        + coverage_result.tocsv(),
593    )
594    return coverage_result
595
596
597def speedup_experiment_fx2trt(args, model_iter_fn, model, example_inputs):
598    """
599    Measure speedups over eager using the TRT inference backend. The TRT backend is based on the
600    fx graph generated by torch._dynamo.
601    Writes to ./speedups_fx2trt.csv
602    """
603    return speedup_experiment(args, model_iter_fn, model, example_inputs)
604
605
606def recompile_profiler_experiment(args, model_iter_fn, model, example_inputs):
607    prof = torch._dynamo.utils.CompilerProfiler()
608    opt_model_iter_fn = torch._dynamo.optimize(prof, nopython=args.nopython)(
609        model_iter_fn
610    )
611    opt_model_iter_fn(model, example_inputs)
612    output_csv(
613        output_filename, ["model", "profiler report"], [current_name, prof.report()]
614    )
615    met = prof.get_metrics()
616    guard_failures = len(met["guard_failures"])
617    return [guard_failures]
618
619
620def randomize_input(inputs):
621    if isinstance(inputs, (list, tuple)):
622        return type(inputs)([randomize_input(x) for x in inputs])
623    elif isinstance(inputs, torch.Tensor):
624        if inputs.dtype in (torch.float32, torch.float64):
625            torch._dynamo.utils.counters["randomize_input"]["times"] += 1
626            return torch.randn_like(inputs)
627        elif inputs.dtype == torch.int64:
628            # Note: we cannot simply randomize integer tensors as follows:
629            #   `return torch.randint_like(inputs, high=inputs.max().item())`
630            # This may break some invariants between tensors.
631            # E.g. in the embedding lookup case, one tensor is the length
632            # and another is an indices tensor.
633            return inputs
634        else:
635            raise RuntimeError(
636                f"randomize_input need support tensor of type {inputs.dtype}"
637            )
638    else:
639        raise RuntimeError(
640            f"randomize_input can not handle input of type {type(inputs)}"
641        )
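# Behavior summary: float32/float64 tensors are replaced with fresh torch.randn_like
# draws, int64 tensors are returned unchanged (to preserve invariants such as
# embedding indices vs. lengths), and any other input type raises a RuntimeError.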
642
643
644def maybe_mark_step(args):
645    if args.trace_on_xla:
646        xm.mark_step()
647
648
649def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
650    """
651    Measure speedups over eager.
652
653    Writes to ./speedups.csv
654    """
655    # if args.dynamic_shapes:
656    #     return speedup_experiment_ds(args, model_iter_fn, model, example_inputs)
657
658    timings = np.zeros((args.repeat, 2), np.float64)
659    # if we randomize the input, we should also check the result is correct
660    should_randomize_input = args.randomize_input
661
662    import contextlib
663
664    from torch._inductor.utils import maybe_profile
665
666    @contextlib.contextmanager
667    def maybe_mark_profile(*args, **kwargs):
668        prof: torch.profiler.profile = kwargs.pop("p", None)
669        mark = kwargs.pop("mark", None)
670        if prof:
671            with torch.profiler.record_function(mark):
672                yield
673        else:
674            yield
675
676    times = args.iterations_per_run
677
678    # Use a higher tolerance for XLA, since XLA causes numerical instability when
679    # the graph size changes
680    tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4
681    torch._dynamo.config.repro_tolerance = tolerance
682
683    with maybe_profile(args.export_profiler_trace) as p:
684        if args.export_aot_inductor:
685            frozen_model_iter_fn = export_aot_inductor(
686                model, example_inputs, args.devices[0]
687            )
688        else:
689            frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
690
691        for rep in trange(args.repeat, desc="running benchmark"):
692            inputs = (
693                randomize_input(copy.deepcopy(example_inputs))
694                if should_randomize_input
695                else example_inputs
696            )
697            # We need to call mark_step to perform the computation
698            # on the randomized input. Otherwise the first call using the
699            # inputs will incur a much higher penalty than the next one.
700            maybe_mark_step(args)
701
702            # interleave the runs to handle frequency scaling and load changes
703            with maybe_mark_profile(p=p, mark="expected"):
704                timings[rep, 0], expected_output = timed(
705                    model,
706                    model_iter_fn,
707                    inputs,
708                    return_result=True,
709                    times=times,
710                    collect_outputs=args.collect_outputs,
711                )
712
713            # call mark_step between the 2 calls to make the comparison fair.
714            maybe_mark_step(args)
715
716            with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
717                args.compiled_autograd
718            ):
719                timings[rep, 1], actual_output = timed(
720                    model,
721                    frozen_model_iter_fn,
722                    inputs,
723                    return_result=True,
724                    times=times,
725                    collect_outputs=args.collect_outputs,
726                )
727
728    if args.export_profiler_trace:
729        name = args.profiler_trace_name + "_" + model.name
730        if hasattr(args, "rank"):
731            name += f"_rank_{args.rank}"
732        name += ".json"
733        name = os.path.join(torch._dynamo.config.base_dir, name)
734        p.export_chrome_trace(name)
735    median = np.median(timings, axis=0)
736    speedup = median[0] / median[1]
737    if args.dump_raw_metrics:
738        np.save(
739            f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy",
740            timings,
741        )
742
743    first_headers = ["dev", "name", "batch_size"]
744    first_fields = [current_device, current_name, current_batch_size]
745    if "tag" in kwargs:
746        first_headers.append("tag")
747        first_fields.append(kwargs["tag"])
748    headers = first_headers + ["speedup", "abs_latency"]
749    row = first_fields + [float(speedup), median[1] * 1000]
750    msg = f"{speedup:.3f}x"
751    if args.baseline:
752        headers.extend(
753            [
754                "baseline",
755                "speedup_vs_baseline",
756            ]
757        )
758        df = pd.read_csv(args.baseline)
759        try:
760            baseline_speedup = df[df["name"] == current_name]["speedup"].item()
761            row.extend([baseline_speedup, speedup / baseline_speedup])
762            msg = f"{baseline_speedup:.3f}x -> {speedup:.3f}x [{speedup / baseline_speedup:.3f}x]"
763        except (KeyError, ZeroDivisionError):
764            row.extend(
765                [
766                    0.0,
767                    0.0,
768                ]
769            )
770    if "compilation_latency" in kwargs:
771        headers += [
772            "compilation_latency",
773            "compression_ratio",
774            "eager_peak_mem",
775            "dynamo_peak_mem",
776        ]
777        row.append(kwargs["compilation_latency"])
778        row.append(kwargs["compression_ratio"])
779        row.append(kwargs["eager_peak_mem"])
780        row.append(kwargs["dynamo_peak_mem"])
781
782    if "cache_lookup_latency" in kwargs:
783        headers.append("cache_lookup_latency")
784        row.append(kwargs["cache_lookup_latency"])
785
786    if "dynamo_stats" in kwargs:
787        for k, v in kwargs["dynamo_stats"].items():
788            headers.append(k)
789            row.append(v)
790    output_csv(
791        output_filename,
792        headers,
793        row,
794    )
795    headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
796    assert (
797        output_filename.find(".csv") > 0
798    ), f"expected output_filename to be a .csv, but got {output_filename}"
799    output_csv(
800        output_filename[:-4] + "_compilation_metrics.csv",
801        first_headers + headers,
802        first_fields + data,
803    )
804    return msg
805
806
807def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
808    """
809    Run dynamic shapes benchmarks.
810
811    Requires dynamic shape compatible models, which provide a list of example inputs.
812
813    Warms up using the first input example and then iterates the inputs,
814    measuring (and expecting minimal) variance between the runtime for different examples.
815
816    """
817    timings = np.zeros((args.repeat, len(example_inputs), 2), np.float64)
818
819    if args.repeat > 5:
820        print(
821            f"\ndynamic shapes experiments are slow, consider setting --repeat less than {args.repeat}\n"
822        )
823
824    nwarmup = 4
825    for rep in range(args.repeat):
826        # Start each rep fresh, e.g. only warmup on example 0
827        torch._dynamo.reset()
828        optimized_model_iter_fn = optimize_ctx(model_iter_fn)
829        for _ in range(nwarmup):
830            optimized_model_iter_fn(model, example_inputs[0])
831
832        for input_idx, inputs in enumerate(example_inputs):
833            # interleave the runs to handle frequency scaling and load changes
834            timings[rep, input_idx, 0] = timed(
835                model, model_iter_fn, inputs, return_result=False
836            )
837            # different from regular speedup_experiment, we _DO_ want to allow recompilation
838            timings[rep, input_idx, 1] = timed(
839                model, optimized_model_iter_fn, inputs, return_result=False
840            )
841    medians = np.median(timings, axis=0)
842    speedups = list(medians[:, 0] / medians[:, 1])
843    speedups_mean = np.mean(speedups)
844    speedups_median = np.median(speedups)
845    speedups_var = np.var(speedups)
846
847    # TODO this x[0] is not going to work in general but bert only has 1 input
848    shapes = [x[0].shape for x in example_inputs]
849    shape_keys = sorted(set(shapes))
850    shape_speedups = {
851        shape: [
852            it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups))
853        ]
854        for shape in shape_keys
855    }
856    output_str = (
857        f"mean: {speedups_mean:.3f}, median: {speedups_median:.3f}, var: {speedups_var:.3f}"
858        + "\nSpeedups by shape: "
859        + "\n".join(
860            [
861                f"{shape}: "
862                + ", ".join([f"{speedup: .3g}" for speedup in shape_speedups[shape]])
863                for shape in shape_keys
864            ]
865        )
866    )
867    output_csv(
868        output_filename,
869        ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"),
870        [
871            current_device,
872            current_name,
873            current_batch_size,
874            speedups_mean,
875            speedups_median,
876            speedups_var,
877        ],
878    )
879    return output_str
880
881
882@contextlib.contextmanager
883def override_synchronize_with_onnx_iobinding(iobinding):
884    global synchronize
885    prev_synchronize = synchronize
886    try:
887        if iobinding is not None:
888
889            def new_synchronize():
890                iobinding.synchronize_inputs()
891                iobinding.synchronize_outputs()
892
893            synchronize = new_synchronize
894        yield
895    finally:
896        synchronize = prev_synchronize
897
898
899def speedup_experiment_onnx(
900    args,
901    model_iter_fn,
902    onnx_model: OnnxModel,
903    model,
904    example_inputs,
905    **kwargs,
906):
907    """
908    Measure speedups over eager.
909
910    This function is responsible for the following:
911        1. Creating iobinding with OnnxModel if device is CUDA, which is essential for perf measurement.
912        2. Running ORT with OnnxModel.
913
914    Writes to ./{output_filename}, which should be
915        `pathlib.Path(self.output_dir) / f"{self.compiler}_{suite}_{self.dtype}_{self.mode}_{self.device}_{self.testing}.csv"`.
916
917    TODO(bowbao): Record export time and export peak memory usage.
918    """
919    timings = np.zeros((args.repeat, 2), np.float64)
920    is_correct = True
921    should_randomize_input = args.randomize_input
922    times = args.iterations_per_run
923
924    def create_onnx_input_binded_fn(onnx_model: OnnxModel, pt_inputs, example_outputs):
925        # Goal is to move the iobinding creation outside of the timer function.
926        iobinding, outputs = onnx_model.create_iobinding(pt_inputs, example_outputs)
927
928        def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
929            onnx_model.run_with_iobinding(iobinding, outputs)
930            if collect_outputs:
931                return outputs
932
933        return onnxrt_model_iter_fn, iobinding
934
935    def create_onnx_fn(onnx_model: OnnxModel, pt_inputs):
936        # NOTE: Making perf comparison fair by moving out the i/o adapting part.
937        # 1. Pre-adapt `pt_inputs` to `onnx_inputs` here.
938        # 2. Drop `onnx_outputs` to `pt_outputs` adapting. Output comparison is not part of perf measurement.
939        onnx_inputs = onnx_model.adapt_pt_inputs_to_onnx(pt_inputs)
940
941        def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
942            return onnx_model.run_with_onnx_inputs(onnx_inputs)
943
944        return onnxrt_model_iter_fn
945
946    def timed_onnx(model, onnx_model: OnnxModel, inputs):
947        if current_device == "cpu" or onnx_model.is_cpu():
948            onnxrt_model_iter_fn = create_onnx_fn(onnx_model, inputs)
949            iobinding = None
950        else:
951            onnxrt_model_iter_fn, iobinding = create_onnx_input_binded_fn(
952                onnx_model, inputs, expected_output
953            )
954        with override_synchronize_with_onnx_iobinding(iobinding):
955            return timed(
956                model,
957                onnxrt_model_iter_fn,
958                inputs,
959                return_result=True,
960                times=times,
961                collect_outputs=args.collect_outputs,
962            )
963
964    # Insert ONNX warm-up
965    inputs = (
966        randomize_input(copy.deepcopy(example_inputs))
967        if should_randomize_input
968        else example_inputs
969    )
970    _, expected_output = timed(
971        model,
972        model_iter_fn,
973        inputs,
974        return_result=True,
975        times=times,
976        collect_outputs=args.collect_outputs,
977    )
978    for _ in range(2):
979        timed_onnx(model, onnx_model, inputs)
980
981    for rep in range(args.repeat):
982        inputs = (
983            randomize_input(copy.deepcopy(example_inputs))
984            if should_randomize_input
985            else example_inputs
986        )
987        if torch.cuda.device_count() > 1:
988            # Manually set correct torch.cuda.current_device to ensure torch.cuda.synchronize() works as intended.
989            # When there are more than 1 cuda devices, the first one is used for pytorch eager.
990            # The second one is used for onnx ort.
991            torch.cuda.set_device(0)
992        timings[rep, 0], expected_output = timed(
993            model,
994            model_iter_fn,
995            inputs,
996            return_result=True,
997            times=times,
998            collect_outputs=args.collect_outputs,
999        )
1000        if torch.cuda.device_count() > 1:
1001            # Manually set correct torch.cuda.current_device to ensure torch.cuda.synchronize() works as intended.
1002            # When there are more than 1 cuda devices, the first one is used for pytorch eager.
1003            # The second one is used for onnx ort.
1004            torch.cuda.set_device(1)
1005        timings[rep, 1], actual_output = timed_onnx(model, onnx_model, inputs)
1006
1007    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
1008    median = np.median(timings, axis=0)
1009    speedup = median[0] / median[1]
1010    if args.dump_raw_metrics:
1011        np.save(
1012            f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy",
1013            timings,
1014        )
1015
1016    headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
1017    row = [
1018        current_device,
1019        current_name,
1020        current_batch_size,
1021        float(speedup),
1022        median[1] * 1000,
1023    ]
1024    if "compilation_latency" in kwargs:
1025        headers = headers + ["compilation_latency", "compression_ratio"]
1026        row.append(kwargs["compilation_latency"])
1027        row.append(kwargs["compression_ratio"])
1028
1029    output_csv(
1030        output_filename,
1031        headers,
1032        row,
1033    )
1034    headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
1035    assert (
1036        output_filename.find(".csv") > 0
1037    ), f"expected output_filename to be a .csv, but got {output_filename}"
1038    output_csv(
1039        output_filename[:-4] + "_compilation_metrics.csv",
1040        ["dev", "name", "batch_size"] + headers,
1041        [current_device, current_name, current_batch_size] + data,
1042    )
1043    return format_speedup(speedup, pvalue, is_correct=is_correct)
1044
1045
1046def overhead_experiment(*args, model_iter_fn):
1047    """
1048    Measure overheads of TorchDynamo by running with no backend (only
1049    eager+FX), and reporting speedup/slowdown over eager.
1050
1051    Writes to ./overheads.csv
1052    """
1053    return speedup_experiment(*args, model_iter_fn)
1054
1055
1056def print_fx(gm, example_inputs):
1057    print(gm.graph)
1058    return gm
1059
1060
1061def print_aten_ops(gm, example_inputs):
1062    from functorch.compile import aot_module
1063
1064    def trace_printer(gm, _):
1065        print(gm.graph)
1066        return gm
1067
1068    return aot_module(gm, fw_compiler=trace_printer, bw_compiler=trace_printer)
1069
1070
1071def baselines(models, model_iter_fn, example_inputs, args):
1072    """
1073    Common measurement code across all baseline experiments.
1074    """
1075    models = list(models)
1076    for idx, (name, model) in enumerate(models):
1077        if idx == 0:
1078            result0 = model_iter_fn(model, example_inputs)
1079        elif model is not None:
1080            try:
1081                result = model_iter_fn(model, example_inputs)
1082                if same(result0, result):
1083                    continue
1084                print(name, "is INCORRECT")
1085            except Exception:
1086                log.exception("error checking %s", name)
1087            models[idx] = (name, None)
1088    timings = np.zeros((args.repeat, len(models)), np.float64)
1089    timings.fill(1.0e10)
1090    for rep in range(args.repeat):
1091        for idx, (name, model) in enumerate(models):
1092            if model is not None:
1093                try:
1094                    timings[rep, idx] = timed(model, model_iter_fn, example_inputs)
1095                except Exception:
1096                    pass
1097    pvalue = [
1098        ttest_ind(timings[:, 0], timings[:, i]).pvalue
1099        for i in range(1, timings.shape[1])
1100    ]
1101    median = np.median(timings, axis=0)
1102    speedup = median[0] / median[1:]
1103    for idx, (name, model) in enumerate(models[1:]):
1104        if model is None:
1105            speedup[idx] = 0.0
1106    result = " ".join(
1107        [
1108            format_speedup(s, p, m is not None)
1109            for s, p, m in zip(speedup, pvalue, [m for n, m in models[1:]])
1110        ]
1111    )
1112    output_csv(
1113        output_filename,
1114        ("dev", "name", "batch_size") + tuple(n for n, m in models[1:]),
1115        [current_device, current_name, current_batch_size]
1116        + [f"{x:.4f}" for x in speedup],
1117    )
1118    return result
1119
1120
1121def xla(args, model_iter_fn, model, example_inputs):
1122    xla_dev = xm.xla_device(devkind=current_device)
1123    model_xla = copy.deepcopy(model).to("cpu").to(device=xla_dev)
1124    example_inputs_xla = tree_map_only(
1125        torch.Tensor, lambda x: x.to("cpu").to(device=xla_dev), example_inputs
1126    )
1127    for _ in range(3):  # warmup
1128        timed(model, model_iter_fn, example_inputs)
1129        timed(model_xla, model_iter_fn, example_inputs_xla)
1130    timings = np.zeros((args.repeat, 2), np.float64)
1131    timings.fill(1.0e10)
1132    for rep in range(args.repeat):
1133        timings[rep, 0] = timed(model, model_iter_fn, example_inputs)
1134        timings[rep, 1] = timed(model_xla, model_iter_fn, example_inputs_xla)
1135
1136    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
1137    time_baseline, time_xla = np.median(timings, axis=0)
1138    speedup = time_baseline / time_xla
1139    output_csv(
1140        output_filename,
1141        ("dev", "name", "batch_size", "speedup", "time_baseline", "time_xla"),
1142        [
1143            current_device,
1144            current_name,
1145            current_batch_size,
1146            speedup,
1147            time_baseline,
1148            time_xla,
1149        ],
1150    )
1151    return format_speedup(speedup, pvalue)
1152
1153
1154def try_script(model, example_inputs):
1155    try:
1156        return torch.jit.script(model)
1157    except Exception:
1158        return None
1159
1160
1161class AOTInductorModelCache:
1162    cache = dict()
1163
1164    @classmethod
1165    def load(cls, model, example_inputs, device):
1166        import torch._inductor
1167        import torch.export._trace
1168
1169        key = weakref.ref(model)
1170        if key not in cls.cache:
1171            # Register the output dataclass to pytree
1172            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1173            with torch.no_grad():
1174                # copy.deepcopy is required to prevent any surprising side-effect,
1175                # see https://github.com/pytorch/pytorch/issues/113029
1176                example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
1177
1178            if pytree._is_namedtuple_instance(example_outputs):
1179                typ = type(example_outputs)
1180                pytree._register_namedtuple(
1181                    typ,
1182                    serialized_type_name=f"{typ.__module__}.{typ.__name__}",
1183                )
1184            else:
1185                _register_dataclass_output_as_pytree(example_outputs)
1186
1187            # TODO(angelayi): change this to predispatch
1188            # https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing
1189            # to predispatch to avoid performance regressions
1190            gm = torch.export._trace._export_to_torch_ir(
1191                model,
1192                example_args,
1193                example_kwargs,
1194            )
1195            with torch.no_grad():
1196                so_path = torch._inductor.aot_compile(
1197                    gm, example_args, example_kwargs
1198                )  # type: ignore[arg-type]
1199
1200            cls.cache[key] = torch._export.aot_load(so_path, device)
1201
1202        return cls.cache[key]
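    # The cache is keyed by a weak reference to the eager model, so each model is
    # exported and AOT-compiled at most once per process; later calls reuse the
    # shared library already loaded via torch._export.aot_load.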
1203
1204
1205def export(model, example_inputs):
1206    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1207    example_outputs = model(*example_args, **example_kwargs)
1208    _register_dataclass_output_as_pytree(example_outputs)
1209
1210    ep = torch.export.export(model, example_args, example_kwargs)
1211
1212    def opt_export(_, example_inputs):
1213        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1214        return ep(*example_args, **example_kwargs)
1215
1216    return opt_export
1217
1218
1219def export_aot_inductor(model, example_inputs, device):
1220    optimized = AOTInductorModelCache.load(model, example_inputs, device)
1221
1222    def opt_aot_inductor(_, example_inputs, collect_outputs=False):
1223        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1224        return optimized(*example_args, **example_kwargs)
1225
1226    return opt_aot_inductor
1227
1228
1229def download_retry_decorator(download_fn):
1230    """
1231    Decorator function for applying retry logic to a download function.
1232
1233    The wrapped function is retried up to MAX_DOWNLOAD_ATTEMPTS times after the initial attempt, and an exception is raised if every attempt fails.
1234    After each unsuccessful attempt, there is a delay before the next attempt, which is increased linearly with the number of tries.
1235
1236    Usage:
1237    @download_retry_decorator
1238    def download_function(model_name: str):
1239        # download logic goes here
1240    """
1241
1242    @functools.wraps(download_fn)
1243    def wrapper(self, *args, **kwargs) -> Any:
1244        tries = 0
1245        total_allowed_tries = MAX_DOWNLOAD_ATTEMPTS
1246        while tries <= total_allowed_tries:
1247            try:
1248                model = download_fn(self, *args, **kwargs)
1249                return model
1250            except Exception as e:
1251                tries += 1
1252                if tries <= total_allowed_tries:
1253                    wait = tries * 30
1254                    print(
1255                        f"Failed to load model: {e}. Trying again ({tries}/{total_allowed_tries}) after {wait}s"
1256                    )
1257                    time.sleep(wait)
1258                else:
1259                    raise RuntimeError(  # noqa: B904
1260                        f"Failed to load model '{args}' with following error(s): {str(e)}."
1261                    )
1262
1263    return wrapper
1264
1265
1266class OnnxModel(abc.ABC):
1267    TORCH_TO_NUMPY_DTYPE = {
1268        torch.float16: np.float16,
1269        torch.float32: np.float32,
1270        torch.float64: np.float64,
1271        torch.uint8: np.uint8,
1272        torch.int8: np.int8,
1273        torch.int16: np.int16,
1274        torch.int32: np.int32,
1275        torch.int64: np.longlong,
1276        torch.bool: np.bool_,
1277    }
1278
1279    _COMPILER_NAME: str
1280
1281    def __init__(
1282        self,
1283        output_directory,
1284        model,
1285        example_inputs,
1286        dynamic_shapes: bool,
1287        copy_before_export: bool = False,
1288    ):
1289        model_name = current_name
1290        self.copy_before_export = copy_before_export
1291        self.model_dir = self._generate_onnx_model_directory(
1292            output_directory, self._COMPILER_NAME, model_name
1293        )
1294        self.model_path = str(
1295            self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
1296        )
1297
1298    def _determine_deepcopy_target_device(self):
1299        if current_device == "cpu":
1300            target_device = "cpu"
1301        else:
1302            if torch.cuda.device_count() > 1:
1303                # Copy to another cuda device to avoid OOM.
1304                target_device = "cuda:1"
1305            else:
1306                target_device = "cuda"
1307        return target_device
1308
1309    def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device):
1310        # Deepcopy model before export to avoid modification to baseline model.
1311        # To avoid OOM, the model is first moved to CPU; the copy is then moved to the target device and the original is moved back to its original device.
1312        model_device = next(model.parameters()).device
1313        model.to("cpu")
1314        model_copy = copy.deepcopy(model).to(target_device)
1315        model.to(model_device)
1316
1317        target_device_example_inputs = tree_map_only(
1318            torch.Tensor, lambda x: x.to(device=target_device), example_inputs
1319        )
1320
1321        return model_copy, target_device_example_inputs
1322
1323    @classmethod
1324    def _generate_onnx_model_directory(
1325        cls, output_directory: str, compiler_name: str, model_name: str
1326    ) -> pathlib.Path:
1327        model_path = pathlib.Path(
1328            output_directory,
1329            ".onnx_models",
1330            model_name,
1331            compiler_name,
1332        )
1333        if model_path.exists() and model_path.is_dir():
1334            shutil.rmtree(model_path)
1335        model_path.mkdir(parents=True, exist_ok=True)
1336        return model_path
1337
1338    @abc.abstractmethod
1339    def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]:
1340        ...
1341
1342    @abc.abstractmethod
1343    def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]:
1344        ...
1345
1346    def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]:
1347        pt_inputs = self.format_pt_inputs(pt_inputs)
1348        return {
1349            ort_input.name: pt_input.cpu().numpy()
1350            for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs)
1351        }
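    # Illustrative result (the actual input names come from the exported ONNX graph):
    #   {"input_ids": <np.ndarray>, "attention_mask": <np.ndarray>, ...}
    # i.e. ORT session input names mapped to CPU numpy copies of the torch inputs.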
1352
1353    def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[np.ndarray]) -> Any:
1354        pt_outputs = [
1355            torch.from_numpy(onnx_output).to(current_device)
1356            for onnx_output in onnx_outputs
1357        ]
1358        if len(pt_outputs) == 1:
1359            return pt_outputs[0]
1360        return pt_outputs
1361
1362    def _init_ort_session(self, model_path: str):
1363        import onnxruntime
1364
1365        if current_device == "cpu":
1366            ort_providers = ["CPUExecutionProvider"]
1367        else:
1368            # NOTE(bowbao): Reduce OOM by running ORT on another gpu.
1369            # TODO(bowbao): This works to avoid OOM, but performance is surprisingly very bad.
1370            cuda_provider_options = {
1371                "device_id": 1 if torch.cuda.device_count() > 1 else 0,
1372            }
1373            ort_providers = [("CUDAExecutionProvider", cuda_provider_options)]
1374        session_options = onnxruntime.SessionOptions()
1375        session_options.log_severity_level = 3  # Error
1376
1377        ort_session = onnxruntime.InferenceSession(
1378            self.model_path,
1379            providers=ort_providers,
1380            sess_options=session_options,
1381        )
1382        return ort_session
1383
1384    def is_cpu(self) -> bool:
1385        return self.onnx_session.get_providers()[0] == "CPUExecutionProvider"
1386
1387    def cpu(self) -> Self:
1388        self.onnx_session.set_providers(["CPUExecutionProvider"])
1389        return self
1390
1391    def create_outputs(self, *example_outputs):
1392        return tuple(torch.empty_like(x) for x in example_outputs)
1393
1394    def create_iobinding(self, pt_inputs, example_outputs):
1395        pt_inputs = self.format_pt_inputs(pt_inputs)
1396        example_outputs = self.format_pt_outputs(example_outputs)
1397
1398        iobinding = self.onnx_session.io_binding()
1399        args = [arg.contiguous() for arg in pt_inputs]
1400        for ort_input, arg in zip(self.onnx_session.get_inputs(), args):
1401            # NOTE: Run ORT on another cuda device to reduce OOM.
1402            if torch.cuda.device_count() > 1:
1403                arg = arg.detach().to("cuda:1")
1404            device = arg.device
1405            iobinding.bind_input(
1406                ort_input.name,
1407                device.type,
1408                device.index or 0,
1409                self.TORCH_TO_NUMPY_DTYPE[arg.dtype],
1410                arg.size(),
1411                arg.data_ptr(),
1412            )
1413
1414        outputs = self.create_outputs(*example_outputs)
1415        for ort_output, output in zip(self.onnx_session.get_outputs(), outputs):
1416            if torch.cuda.device_count() > 1:
1417                output = output.detach().to("cuda:1")
1418            device = output.device
1419            iobinding.bind_output(
1420                ort_output.name,
1421                device.type,
1422                device.index or 0,
1423                self.TORCH_TO_NUMPY_DTYPE[output.dtype],
1424                output.size(),
1425                output.data_ptr(),
1426            )
1427        return iobinding, outputs
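    # Usage sketch (see create_onnx_input_binded_fn above for the real call site):
    #   iobinding, outputs = onnx_model.create_iobinding(pt_inputs, example_outputs)
    #   onnx_model.run_with_iobinding(iobinding, outputs)
    # Binding inputs and outputs up front keeps host/device copies out of the timed region.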
1428
1429    def run_with_iobinding(self, iobinding, outputs):
1430        # 'outputs' are empty torch tensors bound to 'iobinding'.
1431        self.onnx_session.run_with_iobinding(iobinding)
1432        return outputs
1433
1434    def run_with_onnx_inputs(self, onnx_inputs):
1435        return self.onnx_session.run(None, onnx_inputs)
1436
1437    @classmethod
1438    def save_tensor_data(cls, numpy_tensor, output_path):
1439        from onnx import numpy_helper
1440
1441        proto_tensor = numpy_helper.from_array(numpy_tensor)
1442        with open(output_path, "wb") as f:
1443            f.write(proto_tensor.SerializeToString())
1444
1445    def run_and_serialize_inputs_outputs(self, pt_inputs):
1446        test_data_dir = self.model_dir / "test_data_set_0"
1447        test_data_dir.mkdir(parents=True, exist_ok=True)
1448
1449        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
1450        for i, onnx_input in enumerate(onnx_inputs.values()):
1451            self.save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
1452
1453        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
1454
1455        for i, onnx_output in enumerate(onnx_outputs):
1456            self.save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))
1457
1458        return self.adapt_onnx_outputs_to_pt(onnx_outputs)
1459
1460    def run(self, pt_inputs):
1461        # NOTE: For CUDA performance testing, use `run_with_iobinding` to exclude memory
1462        # copying overhead for inputs/outputs between cpu and gpu.
1463        # Otherwise the perf numbers are inaccurate.
1464        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
1465        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
1466        return self.adapt_onnx_outputs_to_pt(onnx_outputs)
1467
1468
1469class OnnxModelFromTorchScript(OnnxModel):
1470    """TorchScript based onnx export. `torch.onnx.export`
1471
1472    TODO(bowbao):
1473    * Large model export fails:
1474          the ONNX model is larger than 2GB, but the exporter makes its decision based on the pt model size,
1475          which is smaller than 2GB.
1476    * OOM on slightly larger models:
1477          both the pt model and the ORT inference session are on GPU. An attempt has been made to move ORT to
1478          cuda:1, however ORT perf drops significantly.
1479          For now everything runs with batch_size 1, set in the launch script.
1480    """
1481
1482    _COMPILER_NAME = "torchscript"
1483
1484    def __init__(
1485        self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs
1486    ):
1487        if dynamic_shapes:
1488            raise NotImplementedError("NYI dynamic shapes for OnnxModelFromTorchScript")
1489        super().__init__(
1490            output_directory, model, example_inputs, dynamic_shapes, **kwargs
1491        )
1492        self._export(
1493            model,
1494            example_inputs,
1495            self.model_path,
1496            opset_version=17,
1497            do_constant_folding=False,
1498            verbose=False,
1499        )
1500        self.onnx_session = self._init_ort_session(self.model_path)
1501
1502    def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
1503        if self.copy_before_export:
1504            # Deepcopy model before export to avoid modification to baseline model.
1505            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1506                model, example_inputs, self._determine_deepcopy_target_device()
1507            )
1508
1509        # Hack for huggingface models (kwargs only).
1510        if isinstance(example_inputs, dict):
1511
1512            class WrapperModel(torch.nn.Module):
1513                def __init__(self, model, keys):
1514                    super().__init__()
1515                    self.model = model
1516                    self.keys = keys
1517
1518                def forward(self, *args):
1519                    return self.model(**dict(zip(self.keys, args)))
1520
1521            model = WrapperModel(model, list(example_inputs.keys()))
1522
1523        torch.onnx.export(
1524            model,
1525            self.format_pt_inputs(example_inputs),
1526            output_path,
1527            **kwargs,
1528        )
1529
1530    def format_pt_inputs(self, pt_inputs):
1531        # NOTE(bowbao): For huggingface benchmark, pt_inputs are formatted as dictionary,
1532        # and consumed like `model(**pt_inputs)`.
1533        # For other benchmarks, pt_inputs are formatted as tuple and consumed
1534        # like `model(*pt_inputs)`.
1535        if isinstance(pt_inputs, dict):
1536            pt_inputs = list(pt_inputs.values())
1537        if isinstance(pt_inputs, torch.Tensor):
1538            pt_inputs = (pt_inputs,)
1539        return tuple(arg.contiguous() for arg in pt_inputs)
1540
1541    def format_pt_outputs(self, pt_outputs):
1542        if isinstance(pt_outputs, torch.Tensor):
1543            pt_outputs = (pt_outputs,)
1544
1545        pt_outputs = pytree.tree_leaves(pt_outputs)
1546
1547        # Hack for huggingface model outputs
1548        try:
1549            from transformers import modeling_outputs
1550        except ImportError:
1551            pass
1552        else:
1553
1554            def _to_tuple(x):
1555                if isinstance(x, modeling_outputs.ModelOutput):
1556                    return x.to_tuple()
1557                return x
1558
1559            pt_outputs = pytree.tree_map(_to_tuple, pt_outputs)
1560            pt_outputs = pytree.tree_leaves(pt_outputs)
1561
1562        return pt_outputs
1563
1564
1565class OnnxModelFromDynamo(OnnxModel):
1566    """Dynamo and Fx based export. `torch.onnx.dynamo_export`."""
1567
1568    _COMPILER_NAME = "dynamo"
1569
1570    def __init__(
1571        self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs
1572    ):
1573        super().__init__(
1574            output_directory, model, example_inputs, dynamic_shapes, **kwargs
1575        )
1576        self._dynamic_shapes = dynamic_shapes
1577        self._onnx_program = self._export(model, example_inputs, self.model_path)
1578        # Clear the model proto to save memory.
1579        # The model proto is saved to disk and no longer needed from `onnx_program`.
1580        # `onnx_program` is kept for i/o adapter usage.
1581        self._onnx_program.model_proto.Clear()
1582        self.onnx_session = self._init_ort_session(self.model_path)
1583
1584    def _export(
1585        self, model, example_inputs, output_path: str
1586    ) -> torch.onnx.ONNXProgram:
1587        if self.copy_before_export:
1588            # Deepcopy model before export to avoid modification to baseline model.
1589            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1590                model, example_inputs, self._determine_deepcopy_target_device()
1591            )
1592
1593        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1594        options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
1595        onnx_program = torch.onnx.dynamo_export(
1596            model, *example_args, **example_kwargs, export_options=options
1597        )
1598
1599        onnx_program.save(output_path)
1600        return onnx_program
1601
1602    def format_pt_inputs(self, pt_inputs):
1603        pt_args, pt_kwargs = _normalize_bench_inputs(pt_inputs)
1604        return self._onnx_program.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs)
1605
1606    def format_pt_outputs(self, pt_outputs):
1607        return self._onnx_program.adapt_torch_outputs_to_onnx(pt_outputs)
1608
1609
1610class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
1611    """Dynamo and Fx based export, with AOT inline post export. `torch.onnx.dynamo_export`."""
1612
1613    _COMPILER_NAME = "dynamo_aot_inline"
1614
1615    def _export(
1616        self, model, example_inputs, output_path: str
1617    ) -> torch.onnx.ONNXProgram:
1618        if self.copy_before_export:
1619            # Deepcopy model before export to avoid modification to baseline model.
1620            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1621                model, example_inputs, self._determine_deepcopy_target_device()
1622            )
1623
1624        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1625        options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
1626        onnx_program = torch.onnx.dynamo_export(
1627            model, *example_args, **example_kwargs, export_options=options
1628        )
1629        # Apply AOT inline post export.
1630        # Requires onnx >= 1.15
1631        import onnx
1632        import onnx.inliner
1633
        # Workaround for the inliner not supporting models larger than 2GB.
        # Save the model to disk first, separating out external data,
        # and load it back without external data for the inliner to work on.
1637        model_proto = onnx_program.model_proto
1638        onnx.save_model(model_proto, output_path, save_as_external_data=True)
1639        model_proto = onnx.load(output_path, load_external_data=False)
1640        model_proto = onnx.inliner.inline_local_functions(model_proto)
1641        onnx.save_model(model_proto, output_path)
1642        return onnx_program
1643
1644
1645class OnnxModelFromDynamoAotOptimize(OnnxModelFromDynamo):
1646    """Dynamo and Fx based export, with AOT optimize post export. `torch.onnx.dynamo_export`."""
1647
1648    _COMPILER_NAME = "dynamo_aot_optimize"
1649
1650    def _export(
1651        self, model, example_inputs, output_path: str
1652    ) -> torch.onnx.ONNXProgram:
1653        if self.copy_before_export:
1654            # Deepcopy model before export to avoid modification to baseline model.
1655            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1656                model, example_inputs, self._determine_deepcopy_target_device()
1657            )
1658
1659        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1660        options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
1661        export_output = torch.onnx.dynamo_export(
1662            model, *example_args, **example_kwargs, export_options=options
1663        )
1664
1665        import onnx
1666        from onnxscript.rewriter.onnxruntime import rewrite
1667
1668        model_proto = rewrite(export_output.model_proto)
1669        onnx.save_model(
1670            model_proto,
1671            output_path,
1672            save_as_external_data=True,
1673            all_tensors_to_one_file=True,
1674        )
1675
1676        return export_output
1677
1678
1679class _OnnxPatch:
1680    @classmethod
1681    def patch_non_tensor_outputs(cls, correct_result, new_result, fp64_outputs):
1682        """Patch non-tensor outputs to make them comparable with the correct result.
1683
1684        ONNX model always returns a flat tuple of tensors, but the PyTorch model outputs
1685        `correct_result` and `fp64_outputs` can be arbitrary types. This function normalizes
1686        the outputs to make them comparable with the ONNX model output.
1687        """
1688        try:
1689            from transformers import modeling_outputs
1690        except ImportError:
1691            has_transformers = False
1692        else:
1693            has_transformers = True
1694
1695        if has_transformers and isinstance(
1696            correct_result, modeling_outputs.ModelOutput
1697        ):
1698            correct_result = correct_result.to_tuple()
1699            fp64_outputs = fp64_outputs.to_tuple() if fp64_outputs is not None else None
1700        elif type(correct_result).__name__ in (
1701            "MaskedLMOutput",
1702            "Seq2SeqLMOutput",
1703            "CausalLMOutputWithCrossAttentions",
1704            "LongformerMaskedLMOutput",
1705            "Instances",
1706            "SquashedNormal",
1707            "Boxes",
1708            "Normal",
1709            "TanhTransform",
1710            "Foo",
1711            "Variable",
1712        ):
1713            # Copied from `same` function in `torch._dynamo.utils`
1714            correct_result = [
1715                value
1716                for key in correct_result.__dict__.keys()
1717                if (value := getattr(correct_result, key)) is not None
1718            ]
1719            fp64_outputs = (
1720                [
1721                    value
1722                    for key in fp64_outputs.__dict__.keys()
1723                    if (value := getattr(fp64_outputs, key)) is not None
1724                ]
1725                if fp64_outputs is not None
1726                else None
1727            )
1728
1729        # Flatten nested tuple of tensors, i.e. past_key_values
1730        correct_result = pytree.tree_leaves(correct_result)
1731        # Hack to put results from different runs on same device.
1732        # This is needed for ONNX CPU fallback benchmark, where PyTorch eager is run on GPU.
1733        # Assuming outputs from a single run are always on same device!
1734        devices = [x.device for x in correct_result if isinstance(x, torch.Tensor)]
1735        assert devices and all(
1736            x == devices[0] for x in devices
1737        ), "All tensors must be on same device!"
1738        device = devices[0]
1739        new_result = pytree.tree_leaves(new_result)
1740        new_result = pytree.tree_map(
1741            lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
1742            new_result,
1743        )
1744        fp64_outputs = pytree.tree_leaves(fp64_outputs)
1745
1746        return correct_result, new_result, fp64_outputs
1747
1748
1749@dataclasses.dataclass
1750class OnnxExportErrorRow:
1751    device: str
1752    model_name: str
1753    batch_size: int
1754    rule_id: Optional[str] = None
1755    rule_name: Optional[str] = None
1756    diagnostic_level: Optional[str] = None
1757    diagnostic_message: Optional[str] = None
1758    exception_type_name: Optional[str] = None
1759    exception_message: Optional[str] = None
1760
1761    def __post_init__(self):
1762        assert (
1763            self.rule_id is not None
1764            and self.rule_name is not None
1765            and self.diagnostic_level is not None
1766            and self.diagnostic_message is not None
1767        ) or self.exception_type_name, (
1768            "Either rule_id, rule_name, diagnostic_level and diagnostic_message "
            "must all be set, or exception_type_name must be set"
1770        )
1771
1772    @property
1773    def headers(self) -> List[str]:
1774        return [field.name for field in dataclasses.fields(self)]
1775
1776    @property
1777    def row(self) -> List[str]:
1778        return [getattr(self, field.name) for field in dataclasses.fields(self)]
1779
1780
1781class OnnxExportErrorParser:
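    """Converts ONNX export diagnostics and raised exceptions into OnnxExportErrorRow
    entries for CSV error reporting."""
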
1782    def __init__(self, device: str, model_name: str, batch_size: int):
1783        self.device = device
1784        self.model_name = model_name
1785        self.batch_size = batch_size
1786
1787    def _qualified_exception_class_name(self, exception: Exception) -> str:
1788        if exception.__class__.__module__ == "builtins":
1789            return exception.__class__.__name__
1790        return f"{exception.__class__.__module__}.{exception.__class__.__name__}"
1791
1792    def parse_diagnostic_context(
1793        self,
1794        diagnostic_context: diagnostics.DiagnosticContext,
1795    ) -> Generator[OnnxExportErrorRow, Any, Any]:
1796        from torch.onnx._internal.fx import diagnostics
1797
1798        for diagnostic in diagnostic_context.diagnostics:
1799            if diagnostic.level >= diagnostics.levels.ERROR:
1800                yield OnnxExportErrorRow(
1801                    device=self.device,
1802                    model_name=self.model_name,
1803                    batch_size=self.batch_size,
1804                    rule_id=diagnostic.rule.id,
1805                    rule_name=diagnostic.rule.name,
1806                    diagnostic_level=diagnostic.level.name,
1807                    diagnostic_message=diagnostic.message,
1808                )
1809
1810    def parse_exception(self, exception: Exception) -> OnnxExportErrorRow:
1811        return OnnxExportErrorRow(
1812            device=self.device,
1813            model_name=self.model_name,
1814            batch_size=self.batch_size,
1815            exception_type_name=self._qualified_exception_class_name(exception),
1816            exception_message=str(exception),
1817        )
1818
1819
1820@dataclasses.dataclass
1821class OnnxContext:
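    """Mutable holder for the exported OnnxModel, shared between optimize_onnx_ctx and
    the `run_n_iterations_onnx` callable it returns so the export is done only once."""
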
1822    onnx_model: Optional[OnnxModel] = None
1823
1824
1825def optimize_onnx_ctx(
1826    output_directory: str,
1827    onnx_model_cls: Type[OnnxModel],
1828    run_n_iterations: Callable,
1829    dynamic_shapes: bool = False,
1830    copy_before_export: bool = False,
1831) -> Callable:
1832    # NOTE(bowbao): This function creates and returns the onnx version of 'run_n_iterations',
1833    # which does the following:
1834    #   1. Export and cache model.
1835    #   2. Create iobinding for ORT.
1836    #   3. Run ORT for n iterations.
1837    # The cached model is stored in 'context' under the returned callable.
1838    context = OnnxContext()
1839    test_data_dumped = False
1840
1841    def run_n_iterations_onnx(model, inputs, n=2):
1842        from torch.onnx._internal import exporter
1843        from torch.onnx._internal.fx import diagnostics
1844
1845        # NOTE(bowbao): Capture all export & ort errors and diagnostics.
1846        # Serialize to csv, to be parsed and summarized later by '._onnx/reporter.py'.
1847        # TODO: Accuracy mismatch is not reported here in csv.
1848        assert (
1849            output_filename.find(".csv") > 0
1850        ), f"expected output_filename to be a .csv, but got {output_filename}"
1851        output_error_filename = output_filename[:-4] + "_export_error.csv"
1852        parser = OnnxExportErrorParser(current_device, current_name, current_batch_size)
1853        try:
1854            nonlocal context
1855            if context.onnx_model is None:
1856                context.onnx_model = onnx_model_cls(
1857                    output_directory,
1858                    model,
1859                    copy.deepcopy(inputs),
1860                    dynamic_shapes=dynamic_shapes,
1861                    copy_before_export=copy_before_export,
1862                )
1863            onnx_model = context.onnx_model
1864
1865            for _ in range(n):
1866                nonlocal test_data_dumped
1867                if not test_data_dumped:
1868                    # Serializes inputs and outputs to .pb files for further offline analysis.
1869                    # Due to this, this function is not and should not be used for perf measurement.
1870                    outputs = onnx_model.run_and_serialize_inputs_outputs(inputs)
1871                    test_data_dumped = True
1872                else:
1873                    outputs = onnx_model.run(inputs)
1874            return outputs
1875        except exporter.OnnxExporterError as e:
1876            # `torch.onnx.dynamo_export` raises error that encloses diagnostics.
1877            diagnostic_context = e.onnx_program.diagnostic_context
1878            for parsed_error in parser.parse_diagnostic_context(diagnostic_context):
1879                output_csv(
1880                    output_error_filename, parsed_error.headers, parsed_error.row
1881                )
1882            if context.onnx_model is not None:
1883                e.onnx_program.save_diagnostics(
1884                    f"{context.onnx_model.model_dir}/"
1885                    f"{current_onnx_compiler}_{current_name}_{current_device}.sarif"
1886                )
1887
1888            # Check also the raw exception that caused export failure.
1889            # Skip if it is already analyzed by diagnostics.
1890            cause_of_exception = e.__cause__
1891            if not isinstance(
1892                cause_of_exception, diagnostics.RuntimeErrorWithDiagnostic
1893            ):
1894                parsed_error = parser.parse_exception(cause_of_exception)
1895                output_csv(
1896                    output_error_filename, parsed_error.headers, parsed_error.row
1897                )
1898            raise
1899        except Exception as e:
1900            # `torch.onnx.export` errors.
1901            # ORT errors.
1902            parsed_error = parser.parse_exception(e)
1903            output_csv(output_error_filename, parsed_error.headers, parsed_error.row)
1904            raise
1905
1906    run_n_iterations_onnx.context = context
1907
1908    return run_n_iterations_onnx
1909
1910
1911def read_batch_size_from_file(args, filename, model_name):
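    # The file is a plain CSV of "model_name,batch_size" rows; a batch size of -1
    # marks the entry as intentionally unset and raises an error below.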
1912    batch_size = None
1913    if os.path.exists("benchmarks"):
1914        filename = os.path.join("benchmarks", filename)
1915    assert os.path.exists(filename), filename
1916    with open(filename) as f:
1917        lines = f.readlines()
1918        lines = [i.split(",") for i in lines if len(i.strip()) > 0]
1919        for val in lines:
1920            cur_name, b = val
1921            if model_name == cur_name:
1922                batch_size = int(b)
1923    if batch_size is None:
1924        log.warning("Could not find batch size for %s", model_name)
1925    elif batch_size == -1:
1926        raise RuntimeError(
1927            f"Batch size is unset for {model_name} in {args.batch_size_file}"
1928        )
1929    print(f"batch size: {batch_size}")
1930    return batch_size
1931
1932
1933class TimeOutException(Exception):
1934    pass
1935
1936
1937def alarm_handler(signum, frame):
1938    raise TimeOutException
1939
1940
1941def exit_after(s):
1942    """
    Decorator that raises TimeOutException if the wrapped fn takes more than
    s seconds to run.
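
    Relies on signal.SIGALRM, so it only works on Unix and in the main thread.

    Illustrative usage (`train_step` is just a stand-in for any callable):

        @exit_after(60)
        def train_step():
            ...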
1945    """
1946
1947    def outer(fn):
1948        def inner(*args, **kwargs):
1949            signal.signal(signal.SIGALRM, alarm_handler)
1950            signal.alarm(s)
1951            try:
1952                result = fn(*args, **kwargs)
1953            finally:
1954                signal.alarm(0)
1955            return result
1956
1957        return inner
1958
1959    return outer
1960
1961
1962def get_peak_memory():
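    # Peak CUDA memory allocated on the current device, converted from bytes to GB.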
1963    return torch.cuda.max_memory_allocated() / 10**9
1964
1965
1966def null_experiment(args, model_iter_fn, model, example_inputs):
1967    """
    A no-op experiment useful for making sure TorchBenchmark alone works properly.
1969    """
1970
1971    return []
1972
1973
1974def cast_to(dtype, model, inputs):
    # Cast the model and any floating-point inputs to the requested dtype.
1976    if dtype == torch.float16:
1977        model = model.half()
1978    else:
1979        model = model.to(dtype)
1980
1981    inputs = tree_map(
1982        lambda x: x.to(dtype)
1983        if isinstance(x, torch.Tensor) and x.is_floating_point()
1984        else x,
1985        inputs,
1986    )
1987    return model, inputs
1988
1989
1990def cast_to_bf16(model, inputs):
1991    return cast_to(torch.bfloat16, model, inputs)
1992
1993
1994def cast_to_fp16(model, inputs):
1995    return cast_to(torch.float16, model, inputs)
1996
1997
1998def cast_to_fp64(model, inputs):
1999    return cast_to(torch.float64, model, inputs)
2000
2001
2002def cast_to_fp32(model, inputs):
2003    return cast_to(torch.float32, model, inputs)
2004
2005
2006class DummyGradScaler:
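    """No-op stand-in for torch.amp.GradScaler: scale() returns the loss unchanged."""
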
2007    def scale(self, loss):
2008        return loss
2009
2010
2011def get_dynamo_stats():
2012    # TODO: consider deepcopy'ing the entire counters struct and
2013    # adding a helper to do subtraction on it
2014    return collections.Counter(
2015        {
2016            "calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"],
2017            "unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"],
2018            "graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()),
2019            # NB: The plus removes zero counts
2020            "unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]),
2021            "autograd_captures": torch._dynamo.utils.counters["compiled_autograd"][
2022                "captures"
2023            ],
2024            "autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"][
2025                "compiles"
2026            ],
2027            "cudagraph_skips": torch._dynamo.utils.counters["inductor"][
2028                "cudagraph_skips"
2029            ],
2030        }
2031    )
2032
2033
2034@contextmanager
2035def maybe_init_distributed(should_init_distributed, rank, world_size, port="6789"):
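    # Optionally bring up a single-node NCCL process group (MASTER_ADDR=localhost)
    # for distributed benchmark runs and tear it down on exit.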
2036    try:
2037        if should_init_distributed:
2038            torch.cuda.set_device(rank)
2039            os.environ["MASTER_ADDR"] = "localhost"
2040            os.environ["MASTER_PORT"] = port
2041            torch.distributed.init_process_group(
2042                "nccl", rank=rank, world_size=world_size
2043            )
2044        yield
2045    finally:
2046        if should_init_distributed:
2047            torch.distributed.destroy_process_group()
2048
2049
2050@contextmanager
2051def maybe_snapshot_memory(should_snapshot_memory, suffix):
2052    # Enables Memory Snapshot tool for memory deep dives:
2053    # https://pytorch.org/blog/understanding-gpu-memory-1/
2054    try:
2055        if should_snapshot_memory:
2056            torch.cuda.memory._record_memory_history(max_entries=100000)
2057        yield
2058    finally:
2059        if should_snapshot_memory:
2060            try:
2061                torch.cuda.memory._dump_snapshot(
2062                    os.path.join(
2063                        torch._dynamo.config.base_dir,
2064                        f"{output_filename.rstrip('.csv')}_{suffix}.pickle",
2065                    )
2066                )
2067            except Exception as e:
2068                logging.error("Failed to save memory snapshot, %s", e)
2069
2070            torch.cuda.memory._record_memory_history(enabled=None)
2071
2072
2073class BenchmarkRunner:
2074    def __init__(self):
2075        self.model_iter_fn = None
2076        self.grad_scaler = DummyGradScaler()
2077        self.autocast = contextlib.nullcontext
2078        self.autocast_arg = {}
2079        self.optimizer = None
2080        self._args = None
2081
2082    def setup_amp(self, current_device=None):
2083        if self.args.only in self.fp32_only_models:
2084            return
2085
2086        devices = [current_device] if current_device else self.args.devices
2087        if self.args.amp:
            # AMP training can lead to small loss values, which can underflow
            # gradient values, resulting in zero gradients. To solve this
            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
            # structure that scales the loss values to prevent underflow. Loss
            # values are big at the beginning of training (therefore not
            # requiring scaling), while loss values tend to get small as the
            # network improves (requiring scaling). GradScaler manages all of
            # this fine-tuning by checking whether gradients are turning to inf
            # and discarding such batches.
2097
            # Since we are not running many iterations, the default init_scale
            # of 65536 is going to turn all gradients to inf. Therefore, we
            # just use an init_scale of 2.0 for benchmarking purposes.
2101
            # Disabling GradScaler because
            #  1) The benchmark setup runs only 2 iterations of fwd-bwd, so it is not useful.
            #  2) The current setup shares the grad_scaler between the eager and dynamo
            #  models, which is bad because GradScaler has state and can adjust the
            #  scaling factor between the eager and dynamo runs, making the accuracy
            #  check harder.
2108            # self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
2109            self.autocast = functools.partial(
2110                torch.amp.autocast, device_type=devices[0]
2111            )
2112            if self.args.amp_dtype:
2113                amp_dtype = (
2114                    torch.float16
2115                    if self.args.amp_dtype == "float16"
2116                    else torch.bfloat16
2117                )
2118                self.autocast_arg["dtype"] = amp_dtype
2119
2120    def init_optimizer(self, name, device, params):
2121        if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
2122            if (name in CI_USE_SGD and self.args.ci) or name in BENCHMARK_USE_SGD:
2123                self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True)
                # Disable multi_tensor_sgd for benchmarking; there isn't a large performance benefit (~1%) to compiling
                # this optimizer because it is a single foreach add, and it increases compile time.
                # After autotuning and fake tensor caching land, we can enable it, because the compile time impact will be lower.
2127                # Fake Tensor caching: https://github.com/pytorch/pytorch/pull/113873
2128                # Autotuning: https://github.com/pytorch/pytorch/issues/117447
2129                self.optimizer.step = torch._dynamo.disable(self.optimizer.step)
2130            else:
2131                self.optimizer = torch.optim.Adam(
2132                    params, lr=0.01, capturable=True, foreach=True
2133                )
2134        else:
2135            self.optimizer = None
2136
2137    @property
2138    def args(self):
2139        return self._args
2140
2141    @args.setter
2142    def args(self, args):
2143        self._args = args
2144
2145    @property
2146    def skip_models(self):
2147        return set()
2148
2149    @property
2150    def skip_models_for_cuda(self):
2151        return set()
2152
2153    @property
2154    def skip_models_for_cpu(self):
2155        return set()
2156
2157    @property
2158    def skip_models_for_freezing(self):
2159        return set()
2160
2161    @property
2162    def slow_models(self):
2163        return set()
2164
2165    @property
2166    def very_slow_models(self):
2167        return set()
2168
2169    @property
2170    def non_deterministic_models(self):
2171        return set()
2172
2173    @property
2174    def fp32_only_models(self):
2175        return set()
2176
2177    @property
2178    def force_amp_for_fp16_bf16_models(self):
2179        return set()
2180
2181    @property
2182    def force_fp16_for_bf16_models(self):
2183        return set()
2184
2185    @property
2186    def skip_not_suitable_for_training_models(self):
2187        return set()
2188
2189    @property
2190    def failing_torchinductor_models(self):
2191        return set()
2192
2193    @property
2194    def failing_fx2trt_models(self):
2195        return set()
2196
2197    @property
2198    def skip_accuracy_checks_large_models_dashboard(self):
2199        return set()
2200
2201    @property
2202    def skip_accuracy_check_as_eager_non_deterministic(self):
2203        return set()
2204
2205    @property
2206    def skip_multiprocess_models(self):
2207        return set()
2208
2209    @property
2210    def skip_models_due_to_control_flow(self):
2211        return set()
2212
2213    @property
2214    def guard_on_nn_module_models(self):
2215        return set()
2216
2217    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
2218        raise NotImplementedError
2219
2220    @property
2221    def equal_nan(self):
2222        equal_nan = True
2223        if self.args.float32:
2224            equal_nan = False
2225        return equal_nan
2226
2227    def iter_models(self, args):
2228        for model_name in self.iter_model_names(args):
2229            for device in args.devices:
2230                try:
2231                    yield self.load_model(
2232                        device,
2233                        model_name,
2234                        batch_size=args.batch_size,
2235                    )
2236                except NotImplementedError:
2237                    continue  # bad benchmark implementation
2238
2239    def deepcopy_model(self, model):
2240        return copy.deepcopy(model)
2241
2242    def cast_based_on_args(self, model, example_inputs):
2243        if self.args.float32 or self.args.only in self.fp32_only_models:
2244            if not self.args.float32:
2245                log.warning("Model %s supports float32 only", self.args.only)
2246            model, example_inputs = cast_to_fp32(model, example_inputs)
2247        elif self.args.float16:
2248            if self.args.only in self.force_amp_for_fp16_bf16_models:
2249                log.warning(
2250                    "Model %s does not support float16, running with amp instead",
2251                    self.args.only,
2252                )
2253                self.args.amp = True
2254                self.setup_amp()
2255            else:
2256                model, example_inputs = cast_to_fp16(model, example_inputs)
2257        elif self.args.bfloat16:
2258            if self.args.only in self.force_amp_for_fp16_bf16_models:
2259                log.warning(
2260                    "Model %s does not support bfloat16, running with amp instead",
2261                    self.args.only,
2262                )
2263                self.args.amp = True
2264                self.setup_amp()
2265            elif self.args.only in self.force_fp16_for_bf16_models:
2266                log.warning(
2267                    "Model %s does not support bfloat16, running with float16 instead",
2268                    self.args.only,
2269                )
2270                model, example_inputs = cast_to_fp16(model, example_inputs)
2271            else:
2272                model, example_inputs = cast_to_bf16(model, example_inputs)
2273
2274        return model, example_inputs
2275
2276    def validate_model(self, model, example_inputs):
2277        """
2278        Runs the eager model with example inputs to ensure that eager passes.
2279        """
2280        model = self.deepcopy_model(model)
2281        example_inputs = clone_inputs(example_inputs)
2282        model, example_inputs = self.cast_based_on_args(model, example_inputs)
2283        try:
2284            self.model_iter_fn(model, example_inputs)
2285        except Exception as e:
2286            raise RuntimeError("Eager run failed") from e
2287
2288    def maybe_cast(self, model, example_inputs):
2289        model, example_inputs = self.cast_based_on_args(model, example_inputs)
2290        return model, example_inputs
2291
2292    def decay_batch_exp(self, batch_size, factor=0.5, divisor=2):
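        # Shrink the batch size by `factor` and round the result to a multiple of
        # `divisor`; once the scaled value is no larger than `divisor`, fall back
        # to decrementing by one. The result is clamped at zero.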
2293        out_batch_size = batch_size * factor
2294        if out_batch_size > divisor:
2295            out_batch_size = (out_batch_size + 1) // divisor * divisor
2296        else:
2297            out_batch_size = batch_size - 1
2298        return max(0, int(out_batch_size))
2299
2300    def batch_size_finder(self, device, model_name, initial_batch_size=1024):
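        # Starting from `initial_batch_size`, run a single eager iteration and
        # exponentially decay the batch size on RuntimeError (typically OOM) until
        # one succeeds; returns 1 if no batch size works.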
2301        batch_size = initial_batch_size
2302        while batch_size >= 1:
2303            empty_gpu_cache(current_device)
2304            try:
2305                device, name, model, example_inputs, _ = self.load_model(
2306                    device,
2307                    model_name,
2308                    batch_size,
2309                )
2310                self.model_iter_fn(model, example_inputs)
2311                return batch_size
2312            except RuntimeError as e:
2313                error_str = str(e)
2314                if "channels_last" in error_str:
2315                    break
2316            batch_size = self.decay_batch_exp(batch_size)
2317        return 1
2318
2319    def run_n_iterations(self, mod, inputs):
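        # Run `self.args.iterations` iterations of the model; only the final
        # iteration collects and returns outputs.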
2320        n = self.args.iterations
2321        for _ in range(n - 1):
2322            self.model_iter_fn(mod, inputs, collect_outputs=False)
2323        return self.model_iter_fn(mod, inputs, collect_outputs=True)
2324
2325    @torch._disable_dynamo(recursive=True)
2326    def optimizer_zero_grad(self, mod):
2327        if self.optimizer is not None:
2328            self.optimizer.zero_grad(True)
2329        else:
2330            mod.zero_grad(True)
2331
2332    def optimizer_step(self):
2333        if self.optimizer is not None:
2334            self.optimizer.step()
2335
2336    def get_benchmark_indices(self, length):
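        # Split the benchmark list into `total_partitions` contiguous chunks and
        # return the [start, end) range owned by `partition_id`; the last partition
        # absorbs any remainder.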
2337        start = self._args.partition_id * (length // self._args.total_partitions)
2338        end = (
2339            (self._args.partition_id + 1) * (length // self._args.total_partitions)
2340            if self._args.partition_id < self._args.total_partitions - 1
2341            else length
2342        )
2343        return start, end
2344
2345    def get_fsdp_auto_wrap_policy(self, model_name: str):
2346        from diffusers.models.transformer_2d import Transformer2DModel
2347        from torchbenchmark.models.nanogpt.model import Block
2348        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
2349
2350        from transformers.models.t5.modeling_t5 import T5Block
2351        from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
2352
2353        from torch.distributed.fsdp.wrap import (
2354            ModuleWrapPolicy,
2355            size_based_auto_wrap_policy,
2356        )
2357
2358        # handcrafted wrap policy
2359        MODEL_FSDP_WRAP = {
2360            "stable_diffusion_unet": (Transformer2DModel,),
2361            "hf_T5": (T5Block,),
2362            "hf_T5_base": (T5Block,),
2363            "hf_T5_large": (T5Block,),
2364            "hf_Whisper": (WhisperEncoderLayer,),
2365            "llama_v2_7b_16h": (LlamaDecoderLayer,),
2366            "nanogpt": (Block,),
2367        }
2368
2369        if model_name not in MODEL_FSDP_WRAP:
2370            # default to using wrap policy based on module size
2371            return functools.partial(
2372                size_based_auto_wrap_policy, recurse=True, min_num_params=int(1e5)
2373            )
2374
2375        return ModuleWrapPolicy(MODEL_FSDP_WRAP[model_name])
2376
2377    def deepcopy_and_maybe_parallelize(self, model):
2378        model = self.deepcopy_model(model)
2379        if self.args.ddp:
2380            assert (
2381                torch.distributed.is_available()
2382            ), "Can't use DDP without a distributed enabled build"
2383            from torch.nn.parallel import DistributedDataParallel as DDP
2384
2385            model = DDP(model, find_unused_parameters=True)
2386        elif self.args.fsdp:
2387            assert (
2388                torch.distributed.is_available()
2389            ), "Can't use FSDP without a distributed enabled build"
2390            from torch.distributed.fsdp import (
2391                FullyShardedDataParallel as FSDP,
2392                MixedPrecision,
2393            )
2394
2395            if self.args.float16:
2396                dtype = torch.float16
2397            elif self.args.bfloat16:
2398                dtype = torch.bfloat16
2399            else:
2400                dtype = torch.float32
2401
2402            mp_policy = MixedPrecision(
2403                param_dtype=dtype,
2404                # Gradient communication precision.
2405                reduce_dtype=dtype,
2406                # Buffer precision.
2407                buffer_dtype=dtype,
2408            )
2409
2410            model = FSDP(
2411                model,
2412                use_orig_params=True,
2413                device_id=torch.cuda.current_device()
2414                if self.args.devices[-1] == "cuda"
2415                else None,
2416                mixed_precision=mp_policy,
2417                limit_all_gathers=True,
2418                auto_wrap_policy=self.get_fsdp_auto_wrap_policy(self.args.only),
2419            )
2420        return model
2421
2422    def check_accuracy(
2423        self, name, model, example_inputs, optimize_ctx, experiment, tag
2424    ):
2425        """
2426        Checks accuracy.
        1) Collect the reference outputs with fp64 datatype. This is useful for error checking.
        2) Check whether eager itself has run-to-run variations by comparing two eager runs.
        3) Run the model with Dynamo and compare its outputs against eager, using the fp64
           reference outputs as a higher-precision baseline.
2429        """
2430        start_stats = get_dynamo_stats()
2431
2432        def record_status(accuracy_status, dynamo_start_stats):
2433            """
2434            Records the status in the csv file
2435            """
2436            if current_name in self.non_deterministic_models:
2437                if accuracy_status in (
2438                    "pass",
2439                    "eager_two_runs_differ",
2440                    "fail_accuracy",
2441                ):
2442                    accuracy_status = "pass"
2443
2444            headers = ["dev", "name", "batch_size", "accuracy"]
2445            fields = [current_device, current_name, current_batch_size, accuracy_status]
2446
2447            if tag is not None:
2448                headers.insert(3, "tag")
2449                fields.insert(3, tag)
2450
2451            dynamo_stats = get_dynamo_stats()
2452            dynamo_stats.subtract(dynamo_start_stats)
2453            for k, v in dynamo_stats.items():
2454                headers.append(k)
2455                fields.append(v)
2456
2457            output_csv(output_filename, headers, fields)
2458            return accuracy_status
2459
2460        if name in self.skip_accuracy_checks_large_models_dashboard:
2461            return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)
2462
2463        with self.pick_grad(name, self.args.training):
2464            # Collect the fp64 reference outputs to be used later for accuracy checking.
2465            fp64_outputs = None
2466            model_fp64 = None
2467            inputs_fp64 = None
2468            try:
2469                model_fp64, inputs_fp64 = cast_to_fp64(
2470                    self.deepcopy_and_maybe_parallelize(model),
2471                    clone_inputs(example_inputs),
2472                )
2473                self.init_optimizer(name, current_device, model_fp64.parameters())
2474                fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
2475                fp64_outputs = tree_map(
2476                    lambda x: x.to(torch.float64)
2477                    if isinstance(x, torch.Tensor) and x.is_floating_point()
2478                    else x,
2479                    fp64_outputs,
2480                )
2481            except Exception:
2482                log.warning(
                    "fp64 golden ref was not generated for %s. Setting accuracy check to cosine",
2484                    name,
2485                )
2486                self.args.cosine = True
2487                fp64_outputs = None
2488            finally:
2489                del model_fp64, inputs_fp64
2490                empty_gpu_cache(current_device)
2491
2492            tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
2493                self.args.training, current_device, name
2494            )
2495
2496            # Cast the model to float16/float32 as necessary
2497            model, example_inputs = self.maybe_cast(model, example_inputs)
2498            accuracy_status = "pass"
2499
2500            # Get results of native pytorch
2501            reset_rng_state()
2502            model_copy = None
2503            try:
2504                model_copy = self.deepcopy_and_maybe_parallelize(model)
2505                self.init_optimizer(name, current_device, model_copy.parameters())
2506                correct_result = self.run_n_iterations(
2507                    model_copy, clone_inputs(example_inputs)
2508                )
2509            except Exception as e:
2510                accuracy_status = (
2511                    "eager_1st_run_OOM"
2512                    if isinstance(e, torch.cuda.OutOfMemoryError)
2513                    else "eager_1st_run_fail"
2514                )
2515                log.exception("")
2516                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2517            finally:
2518                del model_copy
2519                empty_gpu_cache(current_device)
2520
2521            # Rerun native pytorch
2522            reset_rng_state()
2523            model_copy = None
2524            try:
2525                model_copy = self.deepcopy_and_maybe_parallelize(model)
2526                self.init_optimizer(name, current_device, model_copy.parameters())
2527                correct_rerun_result = self.run_n_iterations(
2528                    model_copy, clone_inputs(example_inputs)
2529                )
2530            except Exception as e:
2531                accuracy_status = (
2532                    "eager_2nd_run_OOM"
2533                    if isinstance(e, torch.cuda.OutOfMemoryError)
2534                    else "eager_2nd_run_fail"
2535                )
2536                log.exception("")
2537                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2538            finally:
2539                del model_copy
2540                empty_gpu_cache(current_device)
2541
2542            # Two eager runs should have exactly same result
2543            is_same = True
2544            try:
2545                if (
2546                    name not in self.skip_accuracy_check_as_eager_non_deterministic
2547                    and not same(
2548                        correct_result,
2549                        correct_rerun_result,
2550                        fp64_ref=None,
2551                        cos_similarity=False,
2552                        tol=0,
2553                        equal_nan=self.equal_nan,
2554                    )
2555                ):
2556                    is_same = False
2557            except Exception as e:
2558                # Sometimes torch.allclose may throw RuntimeError
2559                is_same = False
2560
2561            if not is_same:
2562                accuracy_status = "eager_two_runs_differ"
2563                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2564
2565            correct_rerun_result = None
2566
2567            # Run with Dynamo
2568            reset_rng_state()
2569            torch._dynamo.reset()
2570            model_copy = None
2571            try:
2572                model_copy = self.deepcopy_and_maybe_parallelize(model)
2573                self.init_optimizer(name, current_device, model_copy.parameters())
2574                if self.args.export or self.args.export_aot_inductor:
2575                    # apply export on module directly
2576                    # no need for n iterations
2577                    # the logic should be the same to self.model_iter_fn (forward_pass)
2578                    with self.autocast(**self.autocast_arg):
2579                        optimized_model_iter_fn = optimize_ctx(
2580                            model_copy, example_inputs
2581                        )
2582                        new_result = optimized_model_iter_fn(model_copy, example_inputs)
2583                else:
2584                    optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
2585                    with maybe_enable_compiled_autograd(self.args.compiled_autograd):
2586                        new_result = optimized_model_iter_fn(model_copy, example_inputs)
2587            except Exception as e:
2588                log.exception("")
2589                print(
                    "TorchDynamo optimized model failed to run because of the following error"
2591                )
2592                accuracy_status = (
2593                    "OOM"
2594                    if isinstance(e, torch.cuda.OutOfMemoryError)
2595                    else "fail_to_run"
2596                )
2597                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2598            finally:
2599                del model_copy
2600
2601            if name in self.skip_accuracy_check_as_eager_non_deterministic:
2602                return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)
2603
2604            if (
2605                current_onnx_compiler == "torchscript"
2606                or current_onnx_compiler == "dynamo"
2607            ):
2608                # Workaround for ONNX for non-tensor outputs
2609                (
2610                    correct_result,
2611                    new_result,
2612                    fp64_outputs,
2613                ) = _OnnxPatch.patch_non_tensor_outputs(
2614                    correct_result, new_result, fp64_outputs
2615                )
2616                # Relax tolerance for ONNX cuda
2617                if current_device == "cuda":
2618                    tolerance = 1e-2
2619
2620                # TODO: store correct_result into the dumped file for offline onnx model validation.
2621                # The downside and potential problem, is that the output formats may be different.
2622                # E.g., the output order might not match, None might be part of output, etc.
2623
2624            try:
2625                if self.args.training and self.args.amp:
2626                    if process_fn := self.get_output_amp_train_process_func.get(
2627                        name, None
2628                    ):
2629                        correct_result = process_fn(correct_result)
2630                        new_result = process_fn(new_result)
2631                        fp64_outputs = process_fn(fp64_outputs)
2632
2633                if not same(
2634                    correct_result,
2635                    new_result,
2636                    fp64_outputs,
2637                    equal_nan=self.equal_nan,
2638                    cos_similarity=cos_similarity,
2639                    tol=tolerance,
2640                ):
2641                    is_same = False
2642            except Exception as e:
2643                # Sometimes torch.allclose may throw RuntimeError
2644                is_same = False
2645
2646            if not is_same:
2647                if self.args.skip_accuracy_check:
2648                    accuracy_status = "pass_due_to_skip"
2649                else:
2650                    accuracy_status = "fail_accuracy"
2651                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2652
2653        return record_status(accuracy_status, dynamo_start_stats=start_stats)
2654
2655    def check_tolerance(
2656        self, name, model, example_inputs, optimize_ctx, base_device="cpu"
2657    ):
2658        """
2659        Checks tolerance based on https://pytorch.org/docs/stable/generated/torch.allclose.html.
2660        """
2661        tolerance_status = "pass"
2662        if name in self.skip_accuracy_checks_large_models_dashboard:
2663            tolerance_status = "pass_due_to_skip"
2664            return tolerance_status
2665        # Cast the model to float16/float32 as necessary
2666        model, example_inputs = self.maybe_cast(model, example_inputs)
2667
2668        with self.pick_grad(name, self.args.training):
2669            # Get results of native pytorch
2670            reset_rng_state()
2671            model_copy = copy.deepcopy(model)
2672            model_copy = model_copy.to(base_device)
2673            example_inputs_copy = copy.deepcopy(example_inputs)
2674            example_inputs_copy = tree_map(
2675                lambda x: x.to(base_device), example_inputs_copy
2676            )
2677            self.init_optimizer(name, base_device, model_copy.parameters())
2678            correct_result = self.run_n_iterations(model_copy, example_inputs_copy)
2679
2680            # Run with Dynamo
            # Sometimes CI fails with a random triton compilation failure, which will be skipped for now
            # TODO: revisit this after switching to the new Triton runtime
2683            reset_rng_state()
2684            torch._dynamo.reset()
2685            try:
2686                self.init_optimizer(name, current_device, model.parameters())
2687                optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
2688                new_result = optimized_model_iter_fn(model, example_inputs)
2689            except Exception as e:
2690                log.exception("")
2691                print(
                    "TorchDynamo optimized model failed to run because of the following error"
2693                )
2694                return "fail_to_run"
2695
2696            def dump_max_mean_values(tol, ref, res):
2697                if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
2698                    for refi, resi in zip(ref, res):
2699                        dump_max_mean_values(tol, refi, resi)
2700                elif isinstance(ref, dict):
2701                    for k in ref.keys():
2702                        dump_max_mean_values(tol, ref[k], res[k])
2703                elif isinstance(ref, torch.Tensor):
2704                    res = res.to(base_device)
2705                    t = torch.abs(ref - res) / (1 + torch.abs(ref))
2706                    tol.append(t.flatten().to(torch.float32))
2707                return tol
2708
2709            tol = []
2710            dump_max_mean_values(tol, correct_result, new_result)
2711            tol = torch.cat(tol)
2713            max = torch.max(tol)
2714            mean = torch.mean(tol)
2715            div = torch.std(tol)
2716            headers = ["dev", "name", "batch_size", "max", "mean", "std"]
2717            fields = [
2718                current_device,
2719                current_name,
2720                current_batch_size,
2721                max.item(),
2722                mean.item(),
2723                div.item(),
2724            ]
2725            output_csv(output_filename, headers, fields)
2726        return tolerance_status
2727
2728    def run_performance_test(
2729        self, name, model, example_inputs, optimize_ctx, experiment, tag=None
2730    ):
2731        if self.args.xla:
2732            with self.pick_grad(name, self.args.training):
2733                return experiment(*self.maybe_cast(model, example_inputs))
2734
2735        def warmup(fn, model, example_inputs, mode, niters=5):
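            # Run `fn` for `niters` iterations and return the total wall-clock latency
            # (seconds), peak memory (GB), and the delta of dynamo counters accumulated
            # during the run; exits the process if the backend fails.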
2736            peak_mem = 0
2737            start_stats = get_dynamo_stats()
2738            try:
2739                if current_device == "cuda":
2740                    torch.cuda.reset_peak_memory_stats()
2741                    empty_gpu_cache(current_device)
2742                t0 = time.perf_counter()
2743                for _ in range(niters):
2744                    fn(model, example_inputs)
2745                t1 = time.perf_counter()
2746                latency = t1 - t0
2747                if current_device == "cuda":
2748                    peak_mem = get_peak_memory()
2749                elif current_device == "cpu":
2750                    total = psutil.virtual_memory().total
2751                    percentage = psutil.Process(os.getpid()).memory_percent()
2752                    peak_mem = percentage * total / 10**9
2753            except Exception:
2754                log.exception("Backend %s failed in warmup()", mode)
2755                return sys.exit(-1)
2756            dynamo_stats = get_dynamo_stats()
2757            dynamo_stats.subtract(start_stats)
2758            return latency, peak_mem, dynamo_stats
2759
2760        # Cast the model to float16/float32 as necessary
2761        model, example_inputs = self.maybe_cast(model, example_inputs)
2762
2763        # Use distributed wrapping as necessary
2764        model = self.deepcopy_and_maybe_parallelize(model)
2765
2766        self.init_optimizer(name, current_device, model.parameters())
2767
2768        # The self.autocast context is needed for the model we export with aot_compile,
2769        # similar to what we do in the check_accuracy function
2770        ctx = (
2771            self.autocast(**self.autocast_arg)
2772            if self.args.export_aot_inductor
2773            else contextlib.nullcontext()
2774        )
2775
2776        with self.pick_grad(name, self.args.training), ctx:
2777            ok, total = Stats.reset_counters()
2778            experiment_kwargs = {}
2779            if tag is not None:
2780                experiment_kwargs["tag"] = tag
2781            results = []
2782            with maybe_snapshot_memory(
2783                self.args.snapshot_memory, f"eager_{self.args.only}"
2784            ):
2785                eager_latency, eager_peak_mem, _ = warmup(
2786                    self.model_iter_fn, model, example_inputs, "eager"
2787                )
2788                if self.args.use_warm_peak_memory:
2789                    _, eager_peak_mem, _ = warmup(
2790                        self.model_iter_fn, model, example_inputs, "eager", niters=1
2791                    )
2792
2793            if self.args.export_aot_inductor:
2794                t_0 = time.perf_counter()
2795                optimized_model_iter_fn = optimize_ctx
2796                t_1 = time.perf_counter()
2797                aot_compilation_time = t_1 - t_0
2798            else:
2799                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
2800                aot_compilation_time = 0
2801
2802            with maybe_enable_compiled_autograd(
2803                self.args.compiled_autograd
2804            ), maybe_snapshot_memory(
2805                self.args.snapshot_memory, f"compiled_{self.args.only}"
2806            ):
2807                dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
2808                    optimized_model_iter_fn, model, example_inputs, "dynamo"
2809                )
2810                if self.args.use_warm_peak_memory:
2811                    _, dynamo_peak_mem, _ = warmup(
2812                        optimized_model_iter_fn,
2813                        model,
2814                        example_inputs,
2815                        "dynamo",
2816                        niters=1,
2817                    )
2818
2819            if self.args.profile_dynamo_cache_lookup:
2820                with torch.profiler.profile(
2821                    activities=[torch.profiler.ProfilerActivity.CPU]
2822                ) as prof:
2823                    with maybe_enable_compiled_autograd(self.args.compiled_autograd):
2824                        warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
2825
2826                events = list(
2827                    filter(
2828                        lambda event: "TorchDynamo Cache Lookup" in event.key,
2829                        prof.key_averages(),
2830                    )
2831                )
2832                dynamo_cache_lookup_latency = events[0].self_cpu_time_total
2833
2834            compilation_time = dynamo_latency - eager_latency + aot_compilation_time
2835            compression_ratio = (
2836                eager_peak_mem / dynamo_peak_mem if dynamo_peak_mem else 0.0
2837            )
2838            if self.args.print_memory:
2839                print(
2840                    f"memory: eager: {eager_peak_mem:.2f} GB, "
2841                    f"dynamo: {dynamo_peak_mem:.2f} GB, "
2842                    f"ratio: {compression_ratio:.2f}"
2843                )
2844
2845            if self.args.print_compilation_time:
2846                print(f"Compilation time: {compilation_time:.2f}")
2847
2848            if experiment.func is speedup_experiment:
2849                experiment_kwargs["compilation_latency"] = compilation_time
2850                experiment_kwargs["compression_ratio"] = compression_ratio
2851                experiment_kwargs["eager_peak_mem"] = eager_peak_mem
2852                experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
2853                experiment_kwargs["dynamo_stats"] = dynamo_stats
2854                if self.args.profile_dynamo_cache_lookup:
2855                    experiment_kwargs[
2856                        "cache_lookup_latency"
2857                    ] = dynamo_cache_lookup_latency
2858
2859            if experiment.func is coverage_experiment:
2860                ok, total = Stats.reset_counters()
2861                results = []
2862                # run with torch._dynamo few times to populate the cache
2863                for _ in range(3):
2864                    optimized_model_iter_fn(model, example_inputs)
2865                _, frames_second_pass = Stats.reset_counters()  # should be 0
2866                if frames_second_pass > 0:
2867                    optimized_model_iter_fn(model, example_inputs)
2868                    _, frames_third_pass = Stats.reset_counters()  # should be 0
2869                else:
2870                    frames_third_pass = 0
2871
2872                results.append(
2873                    f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
2874                )
2875
2876            if experiment.func is speedup_experiment_onnx:
2877                experiment = functools.partial(
2878                    experiment, optimized_model_iter_fn.context.onnx_model
2879                )
2880
            if not hasattr(model, "name"):
                model.name = name
2883            results.append(experiment(model, example_inputs, **experiment_kwargs))
2884            return " ".join(map(str, results))
2885
2886    def minify_model(
2887        self,
2888        name,
2889        model,
2890        example_inputs,
2891        optimize_ctx,
2892        experiment,
2893        tag,
2894    ):
2895        logging.info("Minifying %s...", name)
2896        os.environ["TORCH_COMPILE_DEBUG"] = "1"
2897        os.environ["TORCHDYNAMO_REPRO_AFTER"] = "dynamo"
2898        os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4"
2899
2900        self.check_accuracy(name, model, example_inputs, optimize_ctx, experiment, tag)
2901
2902        if self.args.output_directory:
2903            repro_dir = self.args.output_directory
2904        else:
2905            repro_dir = torch._dynamo.config.base_dir
2906
2907        try:
2908            shutil.move("repro.py", f"{repro_dir}/{name}_repro.py")
2909        except OSError as e:
2910            logging.error("Could not find repro script for model %s", name)
2911        else:
2912            logging.info(
2913                "Repro script for model %s with minified graph saved to %s",
2914                name,
2915                repro_dir,
2916            )
2917
2918    def maybe_preserve_compile_debug(self, name, status):
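        """
        For models/statuses listed in CI_PRESERVE_COMPILE_DEBUG, move the
        torch_compile_debug directory under test/debug/ so it can be collected
        after the run.
        """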
2919        if (
2920            name in CI_PRESERVE_COMPILE_DEBUG
2921            and status in CI_PRESERVE_COMPILE_DEBUG[name]
2922        ):
2923            src_dir = torch._dynamo.utils.get_debug_dir()
2924            if os.path.isdir(src_dir):
2925                dbg_dir = os.path.join(
2926                    os.getcwd(), "test", "debug", "torch_compile_debug"
2927                )
2928                dst_dir = os.path.join(dbg_dir, os.path.basename(src_dir))
2929                try:
2930                    os.makedirs(dbg_dir, exist_ok=True)
2931                    os.rename(src_dir, dst_dir)
2932                    log.warning("Moved %s to %s", src_dir, dst_dir)
2933                except OSError:
2934                    log.exception("Failed to preserve %s", src_dir)
2935
2936    def run_one_model(
2937        self,
2938        name,
2939        model,
2940        example_inputs,
2941        optimize_ctx,
2942        experiment,
2943        explain=False,
2944        tag=None,
2945    ):
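        """
        Run a single benchmark in the selected mode (--accuracy, --tolerance,
        or --performance), optionally minify on accuracy failure, and report
        dynamo stats and graph breaks collected during the run.
        """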
2946        mode = "train" if self.args.training else "eval"
2947        msg = f"{current_device:4} {mode:5} {current_name:34} "
2948        if tag:
2949            msg += f" {tag:26}"
2950        print(msg, flush=True)
2951
2952        start_stats = get_dynamo_stats()
2953
2954        if self.args.accuracy:
2955            status = self.check_accuracy(
2956                name, model, example_inputs, optimize_ctx, experiment, tag
2957            )
2958            print(status)
2959            if status == "fail_accuracy" and self.args.minify:
2960                self.minify_model(
2961                    name, model, example_inputs, optimize_ctx, experiment, tag
2962                )
2963        elif self.args.tolerance:
2964            status = self.check_tolerance(name, model, example_inputs, optimize_ctx)
2965            print(status)
2966        elif self.args.performance:
2967            status = self.run_performance_test(
2968                name, model, example_inputs, optimize_ctx, experiment, tag
2969            )
2970            print(status)
2971        empty_gpu_cache(current_device)
2972
2973        self.maybe_preserve_compile_debug(name, status)
2974
2975        if self.args.timing:
2976            from torch._dynamo.utils import op_count, print_time_report
2977            from torch.utils._stats import simple_call_counter
2978
2979            print_time_report()
2980            stats = "STATS: "
2981            stats = stats + " | ".join(
2982                itertools.chain(
2983                    [f"call_* op count: {op_count}"],
2984                    (f"{key}:{value}" for key, value in simple_call_counter.items()),
2985                )
2986            )
2987            print(stats)
2988        stats = get_dynamo_stats()
2989        stats.subtract(start_stats)
2990
2991        if explain:
2992            print(
2993                f"Dynamo produced {stats['unique_graphs']} graphs "
2994                f"covering {stats['calls_captured']} ops with "
2995                f"{stats['graph_breaks']} graph breaks ({stats['unique_graph_breaks']} unique)"
2996            )
2997
2998        if explain or self.args.log_graph_breaks or self.args.print_graph_breaks:
2999            filename = f"{os.path.splitext(output_filename)[0]}_graph_breaks.csv"
3000
3001            def add_double_quotes(x):
3002                # Delimiter because reason could have comma
3003                return f'"{x}"'
3004
3005            for graph_break in graph_break_reasons:
3006                reason = add_double_quotes(graph_break.reason)
3007                user_stack = add_double_quotes(
3008                    ", ".join([str(x) for x in graph_break.user_stack])
3009                )
3010                output_csv(
3011                    filename,
3012                    ["model", "reason", "user_stack"],
3013                    [current_name, reason, user_stack],
3014                )
3015
3016        if self.args.stats:
3017            Stats.print_summary()
3018
3019
3020def help(fn):
3021    return fn.__doc__
3022
3023
3024diff_branch_default = "DIFF-BRANCH-DEFAULT"
3025
3026
3027def should_diff_branch(args):
3028    return args.diff_branch != diff_branch_default
3029
3030
3031def parse_args(args=None):
3032    parser = argparse.ArgumentParser()
3033    parser.add_argument(
3034        "--filter", "-k", action="append", help="filter benchmarks with regexp"
3035    )
3036    parser.add_argument(
3037        "--exclude", "-x", action="append", help="filter benchmarks with regexp"
3038    )
3039    parser.add_argument(
3040        "--exclude-exact", action="append", help="filter benchmarks with exact match"
3041    )
3042    parser.add_argument(
3043        "--total-partitions",
3044        type=int,
3045        default=1,
3046        choices=range(1, 10),
3047        help="Total number of partitions we want to divide the benchmark suite into",
3048    )
3049    parser.add_argument(
3050        "--partition-id",
3051        type=int,
3052        default=0,
3053        help="ID of the benchmark suite partition to be run. Used to divide CI tasks",
3054    )
3055    parser.add_argument(
3056        "--devices", "--device", "-d", action="append", help="cpu or cuda"
3057    )
3058    parser.add_argument("--device-index", help="CUDA device index")
3059    parser.add_argument(
3060        "--repeat", "-n", type=int, default=30, help="number of timing runs"
3061    )
3062    iterations_per_run_help = """
3063        Run this many iterations for each time measurement. This is mainly used for
3064        XLA training. We want to run multiple iterations per measurement so that
3065        tracing and computation for different iterations can overlap with each
3066        other. This makes sure we have an accurate XLA baseline.
3067    """
3068    parser.add_argument(
3069        "--iterations-per-run", type=int, default=1, help=iterations_per_run_help
3070    )
3071    parser.add_argument(
3072        "--randomize-input",
3073        action="store_true",
3074        help="Whether to randomize the input values. Dimensions will be kept the same.",
3075    )
3076    parser.add_argument(
3077        "--threads",
3078        "-t",
3079        type=int,
3080        help="number of threads to use for eager and inductor",
3081    )
3082    parser.add_argument(
3083        "--nopython", action="store_true", help="Turn graph breaks into errors"
3084    )
3085    parser.add_argument(
3086        "--no-skip",
3087        action="store_true",
3088        help="run models that are in the global SKIP list",
3089    )
3090    parser.add_argument(
3091        "--prims-nvfuser", action="store_true", help="use prims + nvfuser backend"
3092    )
3093    parser.add_argument(
3094        "--dump-raw-metrics",
3095        action="store_true",
3096        help="dump raw timing metrics from speedup experiment",
3097    )
3098    parser.add_argument(
3099        "--log-operator-inputs",
3100        action="store_true",
3101        default=False,
3102    )
3103    parser.add_argument(
3104        "--channels-last",
3105        action="store_true",
3106        default=False,
3107        help="use channels last format",
3108    )
3109    parser.add_argument(
3110        "--batch-size", "--batch_size", type=int, help="batch size for benchmarking"
3111    )
3112    parser.add_argument(
3113        "--iterations", type=int, default=2, help="how many iterations to run"
3114    )
3115    parser.add_argument(
3116        "--batch-size-file", type=str, help="path to a file to load batch sizes from"
3117    )
3118    parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
3119    parser.add_argument(
3120        "--freezing", action="store_true", help="turn on freezing", default=False
3121    )
3122    parser.add_argument(
3123        "--ci", action="store_true", help="Flag to indicate that it's a CI run"
3124    )
3125    parser.add_argument(
3126        "--dashboard", action="store_true", help="Flag to indicate that it's a Dashboard run"
3127    )
3128    parser.add_argument(
3129        "--skip-fp64-check", action="store_true", help="skip accuracy check using fp64"
3130    )
3131    parser.add_argument(
3132        "--fast", "-f", action="store_true", help="skip slow benchmarks"
3133    )
3134    parser.add_argument(
3135        "--only",
3136        help="""Run just one model from torchbench. Or
3137        specify the path and class name of the model in format like:
3138        --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>
3139
3140        Due to the fact that dynamo changes current working directory,
3141        the path should be an absolute path.
3142
3143        The class should have a method get_example_inputs to return the inputs
3144        for the model. An example looks like
3145        ```
3146        class LinearModel(nn.Module):
3147            def __init__(self):
3148                super().__init__()
3149                self.linear = nn.Linear(10, 10)
3150
3151            def forward(self, x):
3152                return self.linear(x)
3153
3154            def get_example_inputs(self):
3155                return (torch.randn(2, 10),)
3156        ```
3157    """,
3158    )
3159    parser.add_argument(
3160        "--multiprocess",
3161        action="store_true",
3162        help="Create n processes based on the number of devices (distributed use case).",
3163    )
3164    parser.add_argument(
3165        "--ddp",
3166        action="store_true",
3167        help="Wraps model in DDP before running it, and uses dynamo DDPOptimizer (graph breaks) by default.",
3168    )
3169    parser.add_argument(
3170        "--fsdp",
3171        action="store_true",
3172        help="""Wraps model in FSDP before running it.
3173        Doesn't recursively wrap; mainly useful for checking dynamo UnspecNNModule compatibility
3174    """,
3175    )
3176    parser.add_argument(
3177        "--optimize-ddp-mode",
3178        type=str,
3179        default="ddp_optimizer",
3180        help="Specify the DDP optimization mode -- the value of torch._dynamo.config.optimize_ddp.",
3181    )
3182    parser.add_argument(
3183        "--distributed-master-port",
3184        default="6789",
3185        help="Port to bind for torch.distributed. Use the default unless it conflicts with another user",
3186    )
3187    parser.add_argument(
3188        "--dynamic-shapes",
3189        action="store_true",
3190        help="Runs a dynamic shapes version of the benchmark, if available.",
3191    )
3192    parser.add_argument(
3193        "--propagate-real-tensors",
3194        action="store_true",
3195        help="Capture as much data-dependent code as possible by unsoundly propagating real tensors",
3196    )
3197    parser.add_argument(
3198        "--dynamic-batch-only",
3199        action="store_true",
3200        help="Only assume batch dimension is dynamic.  Implies --dynamic-shapes",
3201    )
3202    parser.add_argument(
3203        "--specialize-int", action="store_true", help="Run with specialize_int=True."
3204    )
3205    parser.add_argument(
3206        "--use-eval-mode",
3207        action="store_true",
3208        help="sets model.eval() to reduce randomness",
3209    )
3210    parser.add_argument(
3211        "--skip-accuracy-check",
3212        action="store_true",
3213        help="keeps running even when accuracy fails",
3214    )
3215    parser.add_argument(
3216        "--generate-aot-autograd-stats",
3217        action="store_true",
3218        help="Generates AOT Autograd stats like how many graphs are sent to AOT",
3219    )
3220    parser.add_argument(
3221        "--inductor-settings",
3222        action="store_true",
3223        help="Use same settings as --inductor for baseline comparisons",
3224    )
3225    parser.add_argument(
3226        "--suppress-errors",
3227        action="store_true",
3228        help="Suppress errors instead of raising them",
3229    )
3230    parser.add_argument(
3231        "--output",
3232        help="Overrides the output filename",
3233    )
3234    parser.add_argument(
3235        "--output-directory",
3236        help="Overrides the directory to place output files.",
3237    )
3238    parser.add_argument(
3239        "--disable-output",
3240        action="store_true",
3241        help="Disable writing of output files, e.g., for warm-up runs",
3242    )
3243    parser.add_argument(
3244        "--baseline",
3245        help="Compare with a prior --output",
3246    )
3247    parser.add_argument(
3248        "--part",
3249        default=None,
3250        help="Specify the part of the model to run.",
3251    )
3252    parser.add_argument(
3253        "--export-profiler-trace",
3254        action="store_true",
3255        help="exports trace of kineto profiler",
3256    )
3257    parser.add_argument(
3258        "--profiler-trace-name",
3259        "--profiler_trace_name",
3260        help="Overwrites exported trace name",
3261    )
3262    parser.add_argument(
3263        "--diff-branch",
3264        default=diff_branch_default,
3265        help="delta current branch against given branch.",
3266    )
3267    parser.add_argument(
3268        "--tag", default=None, help="Specify a tag to be included in csv files."
3269    )
3270    parser.add_argument(
3271        "--explain",
3272        action="store_true",
3273        help="print some graph/op statistics during the run, similar to .explain()",
3274    )
3275    parser.add_argument(
3276        "--stats",
3277        action="store_true",
3278        help="print graph counter stats",
3279    )
3280    parser.add_argument(
3281        "--use-warm-peak-memory",
3282        "--use_warm_peak_memory",
3283        action="store_true",
3284        help="Measure peak memory using a warm run to reduce autotuning noise",
3285    )
3286    parser.add_argument(
3287        "--print-memory",
3288        action="store_true",
3289        help="print extra memory statistics",
3290    )
3291    parser.add_argument(
3292        "--print-compilation-time",
3293        action="store_true",
3294        help="print compilation latency",
3295    )
3296    parser.add_argument(
3297        "--print-dataframe-summary",
3298        action="store_true",
3299        help="print dataframe result used for calculating accuracy",
3300    )
3301    parser.add_argument(
3302        "--disable-cudagraphs",
3303        action="store_true",
3304        help="Disables cudagraphs for Inductor",
3305    )
3306    parser.add_argument(
3307        "--disable-split-reductions",
3308        action="store_true",
3309        help="Disables split reductions for Inductor",
3310    )
3311    parser.add_argument(
3312        "--disable-persistent-reductions",
3313        action="store_true",
3314        help="Disables persistent reductions for Inductor",
3315    )
3316    parser.add_argument(
3317        "--disable-divisible-by-16",
3318        action="store_true",
3319        help="Disables divisible by 16 hint to Triton for Inductor",
3320    )
3321    parser.add_argument(
3322        "--inductor-compile-mode",
3323        default=None,
3324        help="torch.compile mode argument for inductor runs.",
3325    )
3326    parser.add_argument(
3327        "--print-graph-breaks",
3328        action="store_true",
3329        help="Show a warning whenever a graph break occurs",
3330    )
3331    parser.add_argument(
3332        "--log-graph-breaks",
3333        action="store_true",
3334        help="log graph breaks in a file",
3335    )
3336    parser.add_argument(
3337        "--trace-on-xla",
3338        action="store_true",
3339        help="Whether to trace the model on XLA or on eager device",
3340    )
3341    parser.add_argument(
3342        "--xla-tolerance",
3343        type=float,
3344        default=1e-2,
3345        help="XLA needs a loose tolerance to pass the correctness check",
3346    )
3347    parser.add_argument(
3348        "--collect-outputs",
3349        action="store_true",
3350        help="""Whether to collect outputs for training. Set this to true if we
3351        want to verify the numerical correctness of gradients, but note that
3352        this may make the time measurement less accurate""",
3353    )
3354    parser.add_argument(
3355        "--enable-activation-checkpointing",
3356        action="store_true",
3357        help="Enables activation checkpointing for HF models",
3358    )
3359    parser.add_argument("--timing", action="store_true", help="Emits phase timing")
3360
3361    parser.add_argument(
3362        "--progress",
3363        action="store_true",
3364        help="Print n/k models message between each model run.",
3365    )
3366
3367    parser.add_argument(
3368        "--timeout",
3369        type=int,
3370        default=2000,
3371        help="timeout (seconds) for benchmarking.",
3372    )
3373
3374    parser.add_argument(
3375        "--per_process_memory_fraction",
3376        type=float,
3377        default=1,
3378        help="Set per-process GPU memory fraction (limit) for reducing usable size and reproducing OOMs",
3379    )
3380
3381    parser.add_argument(
3382        "--no-translation-validation",
3383        action="store_true",
3384        help="Disable translation validation for accuracy builds.",
3385    )
3386
3387    parser.add_argument(
3388        "--minify",
3389        action="store_true",
3390        help="Enable minification when the accuracy check fails. Save a repro script for each failing model.",
3391    )
3392
3393    parser.add_argument(
3394        "--compiled-autograd",
3395        action="store_true",
3396        help="Enables compiled autograd on compiled benchmark",
3397    )
3398
3399    parser.add_argument(
3400        "--profile_dynamo_cache_lookup",
3401        "--profile-dynamo-cache-lookup",
3402        action="store_true",
3403        help="profiles TorchDynamo cache lookup",
3404    )
3405
3406    parser.add_argument(
3407        "--snapshot-memory",
3408        "--snapshot_memory",
3409        action="store_true",
3410        help="Enables Memory Snapshot tool for memory deep dives: https://pytorch.org/blog/understanding-gpu-memory-1/",
3411    )
3412
3413    group_latency = parser.add_mutually_exclusive_group()
3414    group_latency.add_argument(
3415        "--cold-start-latency",
3416        "--cold_start_latency",
3417        action="store_true",
3418        help="Use a fresh Triton cache dir when running each model, to force cold-start compilation.",
3419    )
3420    group_latency.add_argument(
3421        "--warm-start-latency",
3422        "--warm_start_latency",
3423        action="store_true",
3424        help="Run model(s) twice and preserve caches in between to enable a 'warm start' on the 2nd run",
3425    )
3426
3427    group_fuser = parser.add_mutually_exclusive_group()
3428    # --nvfuser is now the default, keep the option to not break scripts
3429    group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
3430    group_fuser.add_argument("--nnc", action="store_true", help="enable NNC for GPUs")
3431
3432    group_prec = parser.add_mutually_exclusive_group()
3433    group_prec.add_argument("--float16", action="store_true", help="cast model to fp16")
3434    group_prec.add_argument(
3435        "--bfloat16", action="store_true", help="cast model to bf16"
3436    )
3437    group_prec.add_argument("--float32", action="store_true", help="cast model to fp32")
3438    group_prec.add_argument(
3439        "--amp", action="store_true", help="use automatic mixed precision"
3440    )
3441    parser.add_argument(
3442        "--amp-dtype",
3443        choices=("bfloat16", "float16"),
3444        help="the data type used with automatic mixed precision",
3445    )
3446    group_printout = parser.add_mutually_exclusive_group()
3447    group_printout.add_argument(
3448        "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
3449    )
3450    group_printout.add_argument(
3451        "--quiet", "-q", action="store_true", help="suppress debug printouts"
3452    )
3453
3454    group = parser.add_mutually_exclusive_group()
3455    group.add_argument(
3456        "--coverage", action="store_true", help="(default) " + help(coverage_experiment)
3457    )
3458    group.add_argument(
3459        "--overhead", action="store_true", help=help(overhead_experiment)
3460    )
3461    group.add_argument(
3462        "--speedup-dynamo-ts",
3463        action="store_true",
3464        help="TorchDynamo frontend with torchscript backend",
3465    )
3466    group.add_argument(
3467        "--speedup-fx2trt", action="store_true", help=help(speedup_experiment_fx2trt)
3468    )
3469    group.add_argument(
3470        "--speedup-fx2trt-fp16",
3471        action="store_true",
3472        help=help(speedup_experiment_fx2trt),
3473    )
3474    group.add_argument(
3475        "--print-fx",
3476        action="store_true",
3477        help="Print fx traces captured from model",
3478    )
3479    group.add_argument(
3480        "--print-aten-ops",
3481        action="store_true",
3482        help="Print traces of aten ops captured by AOT autograd",
3483    )
3484    group.add_argument(
3485        "--inductor",
3486        action="store_true",
3487        help="Measure speedup with TorchInductor",
3488    )
3489    group.add_argument(
3490        "--quantization",
3491        choices=[
3492            "int8dynamic",
3493            "int8weightonly",
3494            "int4weightonly",
3495            "autoquant",
3496            "noquant",
3497        ],
3498        default=None,
3499        help="Measure speedup of torchao quantization with TorchInductor baseline",
3500    )
3501    group.add_argument(
3502        "--export",
3503        action="store_true",
3504        help="Measure pass rate with export",
3505    )
3506    group.add_argument(
3507        "--export-aot-inductor",
3508        action="store_true",
3509        help="Measure pass rate with Export+AOTInductor",
3510    )
3511    group.add_argument(
3512        "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
3513    )
3514    group.add_argument(
3515        "--torchscript-onnx",
3516        "--torchscript_onnx",
3517        action="store_true",
3518        help="Measure speedup with TorchScript ONNX, i.e. `torch.onnx.export`",
3519    )
3520    group.add_argument(
3521        "--dynamo-onnx",
3522        "--dynamo_onnx",
3523        action="store_true",
3524        help="Measure speedup with Dynamo ONNX, i.e. `torch.onnx.dynamo_export`",
3525    )
3526    group.add_argument(
3527        "--dynamo-onnx-aot-inline",
3528        "--dynamo_onnx_aot_inline",
3529        action="store_true",
3530        help="Measure speedup with Dynamo ONNX AOT Inline, i.e. `torch.onnx.dynamo_export`",
3531    )
3532    group.add_argument(
3533        "--dynamo-onnx-aot-optimize",
3534        "--dynamo_onnx_aot_optimize",
3535        action="store_true",
3536        help="Measure speedup with Dynamo ONNX w/ ort fusions, i.e. `torch.onnx.dynamo_export`",
3537    )
3538    group.add_argument(
3539        "--backend",
3540        choices=torch._dynamo.list_backends(exclude_tags=None),
3541        help="measure speedup with a given backend",
3542    )
3543    group.add_argument("--nothing", action="store_true", help=help(null_experiment))
3544    group.add_argument(
3545        "--log-conv-args",
3546        action="store_true",
3547        help="Dump convolution input/weight/bias's shape/stride/dtype and other options to json",
3548    )
3549    group.add_argument(
3550        "--recompile-profiler",
3551        "--recompile_profiler",
3552        action="store_true",
3553        help="Run the dynamo recompilation profiler on each model.",
3554    )
3555    group.add_argument(
3556        "--find-batch-sizes",
3557        action="store_true",
3558        help="finds the largest batch size that could fit on GPUs",
3559    )
3560
3561    mode_group = parser.add_mutually_exclusive_group(required=True)
3562    mode_group.add_argument(
3563        "--accuracy",
3564        action="store_true",
3565        help="Checks accuracy with small batch size and eval mode",
3566    )
3567    mode_group.add_argument(
3568        "--performance", action="store_true", help="Measures performance speedup"
3569    )
3570    mode_group.add_argument(
3571        "--tolerance",
3572        action="store_true",
3573        help="extracts the tolerance for each model with small batch size and eval mode",
3574    )
3575    run_mode_group = parser.add_mutually_exclusive_group(required=True)
3576    run_mode_group.add_argument(
3577        "--training",
3578        action="store_true",
3579        help="Performs training",
3580    )
3581    run_mode_group.add_argument(
3582        "--inference", action="store_true", help="Performs inference"
3583    )
3584    return parser.parse_args(args)
3585
3586
3587def process_entry(rank, runner, original_dir, args):
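    # Entry point for each worker process: used directly on the single-process
    # path and via mp.spawn() when --multiprocess launches one process per device.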
3588    args.rank = rank
3589    with maybe_init_distributed(
3590        args.init_distributed,
3591        rank=rank,
3592        world_size=args.world_size,
3593        port=args.distributed_master_port,
3594    ):
3595        return run(runner, args, original_dir)
3596
3597
3598def maybe_fresh_cache(args):
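    """
    Return a fresh_inductor_cache() context for cold-start/warm-start/CI runs,
    unless the user pinned a cache via TORCHINDUCTOR_CACHE_DIR; otherwise a
    null context.
    """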
3599    cache_dir_assigned = "TORCHINDUCTOR_CACHE_DIR" in os.environ
3600    if not cache_dir_assigned and (
3601        args.cold_start_latency or args.warm_start_latency or args.ci
3602    ):
3603        return fresh_inductor_cache()
3604    else:
3605        return contextlib.nullcontext()
3606
3607
3608def main(runner, original_dir=None, args=None):
3609    if original_dir:
3610        os.chdir(original_dir)
3611    args = parse_args() if not args else parse_args(args)
3612    if args.baseline:
3613        args.baseline = os.path.abspath(args.baseline)
3614
3615    if should_diff_branch(args):
3616        import git
3617
3618        # We do this here so we error out earlier if there's an issue
3619        repo = git.Repo()
3620        if repo.is_dirty():
3621            raise RuntimeError(
3622                "--diff-branch called on dirty branch. Commit, stash, or reset."
3623            )
3624        main_branch = repo.active_branch.name
3625        if main_branch == args.diff_branch:
3626            raise RuntimeError(
3627                f"--diff-branch: current branch is same as {args.diff_branch} branch, what are you diffing?"
3628            )
3629
3630    with maybe_fresh_cache(args):
3631        args.init_distributed = args.only and args.multiprocess
3632        if args.init_distributed:
3633            # NB: Do NOT query device count before CUDA initialization; we're
3634            # going to overwrite CUDA_VISIBLE_DEVICES and this will result in
3635            # https://github.com/pytorch/pytorch/issues/107300
3636            device_count = torch.cuda.device_count()
3637            if device_count <= 1:
3638                log.warning(
3639                    "The --multiprocess flag is set but <= 1 devices are available."
3640                )
3641            # multiprocess path
3642            args.world_size = device_count
3643            mp.spawn(
3644                process_entry, args=(runner, original_dir, args), nprocs=device_count
3645            )
3646        elif args.only and args.warm_start_latency:
3647            # Warm start mode. Enable FX graph caching and perform back-to-back runs in
3648            # separate processes (but ensure the inductor cache is preserved across runs).
3649            env = os.environ.copy()
3650            env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
3651            cmd = [sys.executable] + sys.argv
3652            cmd.remove("--warm-start-latency")
3653
3654            print(f"Performing cold-start run for {args.only}")
3655            warmup_cmd = cmd + ["--repeat=1", "--disable-output"]
3656            subprocess.check_call(warmup_cmd, timeout=args.timeout, env=env)
3657
3658            print(f"Performing warm-start run for {args.only}")
3659            subprocess.check_call(cmd, timeout=args.timeout, env=env)
3660        else:
3661            # single process path just uses the main process
3662            args.world_size = 1
3663            process_entry(0, runner, original_dir, args)
3664
3665
3666def write_csv_when_exception(args, name: str, status: str, device=None):
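    """
    Write one placeholder row per device (batch_size 0, zeroed metrics) so a
    model that failed to load or run still appears in the output CSV; for
    accuracy runs the status goes into the accuracy column.
    """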
3667    print(status)
3668    placeholder_batch_size = 0
3669    devices = [device] if device is not None else args.devices
3670    if args.accuracy:
3671        headers = ["dev", "name", "batch_size", "accuracy"]
3672        rows = [[device, name, placeholder_batch_size, status] for device in devices]
3673    elif args.performance:
3674        headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
3675        rows = [[device, name, placeholder_batch_size, 0.0, 0.0] for device in devices]
3676    else:
3677        headers = []
3678        rows = [[device, name, placeholder_batch_size, 0.0] for device in devices]
3679
3680    for row in rows:
3681        output_csv(output_filename, headers, row)
3682
3683
3684def run(runner, args, original_dir=None):
3685    # Pass the parsed args object to benchmark runner object
3686    runner.args = args
3687
3688    args.filter = args.filter or [r"."]
3689    args.exclude = args.exclude or [r"^$"]
3690    args.exclude_exact = args.exclude_exact or []
3691
3692    if args.inductor:
3693        assert args.backend is None
3694        args.backend = "inductor"
3695    if args.quantization:
3696        assert args.backend is None
3697        args.backend = "torchao"
3698    if args.dynamic_batch_only:
3699        args.dynamic_shapes = True
3700        torch._dynamo.config.assume_static_by_default = True
3701    if args.dynamic_shapes:
3702        if not args.dynamic_batch_only:
3703            torch._dynamo.config.assume_static_by_default = False
3704    if args.propagate_real_tensors:
3705        # TODO: Separate flag for data dependent
3706        torch._dynamo.config.capture_scalar_outputs = True
3707        torch._dynamo.config.capture_dynamic_output_shape_ops = True
3708        torch._functorch.config.fake_tensor_propagate_real_tensors = True
3709    if args.specialize_int:
3710        torch._dynamo.config.specialize_int = True
3711    if args.ci:
3712        if args.accuracy:
3713            # Run fewer iterations when checking accuracy
3714            args.repeat = min(args.repeat, 2)
3715
3716            # Set translation validation on by default on CI accuracy runs.
3717            torch.fx.experimental._config.translation_validation = True
3718
3719        ci = functools.partial(
3720            CI, args.backend, training=args.training, dynamic=args.dynamic_shapes
3721        )
3722    if args.ddp:
3723        assert args.training, "DDP benchmark requires --training mode"
3724        torch._dynamo.config.optimize_ddp = args.optimize_ddp_mode
3725        if args.only == "dlrm":
3726            log.error(
3727                "DLRM+DDP is unsupported as it requires sharding the embedding layer separately from DDP"
3728            )
3729            return sys.exit(-1)
3730    if args.accuracy:
3731        # Use small batch size. We use >1 batch size to ensure we test
3732        # batch_norm type of operators that work on batch dims.
3733        # TODO - Go through the failures for batch size = 2
3734        if args.batch_size is None:
3735            if runner.suite_name == "huggingface":
3736                args.batch_size = 1
3737            elif runner.suite_name == "torchbench":
3738                args.batch_size = 4
3739            else:
3740                # Larger batch size for TIMM models to have stable batch_norm
3741                assert runner.suite_name == "timm_models"
3742                args.batch_size = 8
3743
3744        # Remove sources of randomness
3745        if runner.suite_name not in ("timm_models", "huggingface"):
3746            # TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well.
3747            args.use_eval_mode = True
3748        inductor_config.fallback_random = True
3749        if args.only is not None and args.only not in {
3750            "alexnet",
3751            "Background_Matting",
3752            "pytorch_CycleGAN_and_pix2pix",
3753            "pytorch_unet",
3754            "Super_SloMo",
3755            "vgg16",
3756            # https://github.com/pytorch/pytorch/issues/96724
3757            "Wav2Vec2ForCTC",
3758            "Wav2Vec2ForPreTraining",
3759            "sam",
3760            "sam_fast",
3761            "resnet50_quantized_qat",
3762            "mobilenet_v2_quantized_qat",
3763        }:
3764            # some of the models do not support use_deterministic_algorithms
3765            torch.use_deterministic_algorithms(True)
3766        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
3767        torch.backends.cudnn.deterministic = True
3768        torch.backends.cudnn.allow_tf32 = False
3769        torch.backends.cudnn.benchmark = False
3770        torch.backends.cuda.matmul.allow_tf32 = False
3771
3772        # Remove randomness when torch.manual_seed is called
3773        patch_torch_manual_seed()
3774
3775        # Some models, e.g. yolov3, assert batch size based on n_gpus
3776        if "CUDA_VISIBLE_DEVICES" not in os.environ and not args.multiprocess:
3777            args.device_index = "0"
3778
3779        # Stricter check to disable fallbacks
3780        args.suppress_errors = False
3781
3782    if args.device_index is not None:
3783        if args.multiprocess:
3784            print("Cannot specify both --device_index and --multiprocess")
3785            return sys.exit(-1)
3786        os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index
3787
3788    elif args.performance:
3789        # Ensure that we test on real scenarios
3790        args.use_eval_mode = False
3791
3792    if args.partition_id >= args.total_partitions or args.partition_id < 0:
3793        print("Invalid partition id")
3794        return sys.exit(-1)
3795
3796    if not args.devices:
3797        if torch.cuda.is_available():
3798            args.devices = ["cuda"]
3799        else:
3800            log.warning("torch.cuda.is_available() == False, using CPU")
3801            args.devices = ["cpu"]
3802
3803    if args.devices != ["cpu"] and (HAS_CUDA or HAS_XPU):
3804        global synchronize
3805        synchronize = torch.cuda.synchronize if HAS_CUDA else torch.xpu.synchronize
3806
3807    if (
3808        args.devices == ["cuda"]
3809        and torch.cuda.get_device_properties(0).total_memory < 25 * 2**30
3810    ):
3811        # OOM errors on an RTX 3090 with 24 GB of GPU memory
3812        runner.skip_models.update(
3813            {
3814                # torchbench
3815                "hf_Longformer",
3816                "timm_nfnet",
3817                "timm_efficientdet",
3818            }
3819        )
3820        if args.training:
3821            runner.skip_models.add("hf_T5")
3822
3823    if args.nnc:
3824        torch._C._jit_override_can_fuse_on_cpu(True)
3825        torch._C._jit_override_can_fuse_on_gpu(True)
3826        torch._C._jit_set_texpr_fuser_enabled(True)
3827        torch._C._jit_set_nvfuser_enabled(False)
3828
3829    if args.threads:
3830        torch.set_num_threads(args.threads)
3831
3832    if args.verbose:
3833        torch._logging.set_logs(dynamo=logging.DEBUG)
3834
3835    if args.print_graph_breaks:
3836        torch._logging.set_logs(graph_breaks=True)
3837
3838    if args.quiet:
3839        torch._logging.set_logs(dynamo=logging.ERROR)
3840
3841    torch._dynamo.config.suppress_errors = args.suppress_errors
3842
3843    if args.training:
3844        runner.model_iter_fn = runner.forward_and_backward_pass
3845        runner.skip_models.update(runner.skip_not_suitable_for_training_models)
3846    else:
3847        runner.model_iter_fn = runner.forward_pass
3848
3849    if args.fast:
3850        runner.skip_models.update(runner.slow_models)
3851
3852    if args.devices == ["cpu"]:
3853        runner.skip_models.update(runner.very_slow_models)
3854        runner.skip_models.update(runner.skip_models_for_cpu)
3855    elif args.devices == ["cuda"]:
3856        runner.skip_models.update(runner.skip_models_for_cuda)
3857
3858    if not args.multiprocess:
3859        runner.skip_models.update(runner.skip_multiprocess_models)
3860
3861    if args.freezing:
3862        runner.skip_models.update(runner.skip_models_for_freezing)
3863
3864    if args.no_skip:
3865        runner.skip_models.clear()
3866
3867    experiment = null_experiment
3868    global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler
3869    optimize_ctx = contextlib.nullcontext()
3870
3871    if args.disable_output:
3872        disable_output = True
3873
3874    if args.overhead:
3875        optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython)
3876        experiment = speedup_experiment
3877        output_filename = "overheads.csv"
3878    elif args.inductor:
3879        inductor_config.debug = args.verbose
3880        if args.threads:
3881            inductor_config.cpp.threads = args.threads
3882
3883        optimize_ctx = functools.partial(
3884            torch.compile,
3885            backend="inductor",
3886            fullgraph=args.nopython,
3887            mode=args.inductor_compile_mode,
3888        )
3889        experiment = speedup_experiment
3890        output_filename = "inductor.csv"
3891    elif args.export:
3892        optimize_ctx = export
3893        experiment = speedup_experiment
3894        output_filename = "export.csv"
3895    elif args.xla:
3896        (dev,) = args.devices
3897        os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]
3898        torch._dynamo.mark_dynamic = MagicMock()
3899        experiment = xla
3900        output_filename = "xla.csv"
3901    elif args.torchscript_onnx:
3902        optimize_ctx = functools.partial(
3903            optimize_onnx_ctx,
3904            args.output_directory or ".",
3905            OnnxModelFromTorchScript,
3906            copy_before_export=args.performance,  # Accuracy bench already did deepcopy
3907        )
3908        experiment = speedup_experiment_onnx
3909        output_filename = "torchscript_onnx.csv"
3910        current_onnx_compiler = "torchscript"
3911    elif args.dynamo_onnx:
3912        optimize_ctx = functools.partial(
3913            optimize_onnx_ctx,
3914            args.output_directory or ".",
3915            OnnxModelFromDynamo,
3916            dynamic_shapes=args.dynamic_shapes,
3917            copy_before_export=args.performance,
3918        )
3919        experiment = speedup_experiment_onnx
3920        output_filename = "dynamo_onnx.csv"
3921        current_onnx_compiler = "dynamo"
3922    elif args.dynamo_onnx_aot_inline:
3923        optimize_ctx = functools.partial(
3924            optimize_onnx_ctx,
3925            args.output_directory or ".",
3926            OnnxModelFromDynamoAotInline,
3927            dynamic_shapes=args.dynamic_shapes,
3928            copy_before_export=args.performance,
3929        )
3930        experiment = speedup_experiment_onnx
3931        output_filename = "dynamo_onnx_aot_inline.csv"
3932        current_onnx_compiler = "dynamo"
3933    elif args.dynamo_onnx_aot_optimize:
3934        optimize_ctx = functools.partial(
3935            optimize_onnx_ctx,
3936            args.output_directory or ".",
3937            OnnxModelFromDynamoAotOptimize,
3938            dynamic_shapes=args.dynamic_shapes,
3939            copy_before_export=args.performance,
3940        )
3941        experiment = speedup_experiment_onnx
3942        output_filename = "dynamo_onnx_aot_optimize.csv"
3943        current_onnx_compiler = "dynamo"
3944    elif args.speedup_dynamo_ts:
3945        optimize_ctx = torch._dynamo.optimize("ts", nopython=args.nopython)
3946        experiment = speedup_experiment
3947        output_filename = "speedup_dynamo_ts.csv"
3948    elif args.prims_nvfuser:
3949        optimize_ctx = torch._dynamo.optimize("prims_nvfuser", nopython=args.nopython)
3950        experiment = speedup_experiment
3951        backend_str = "prims_nvfuser"
3952        output_filename = f"accuracy_aot_{backend_str}.csv"
3953    elif args.print_fx:
3954        optimize_ctx = torch._dynamo.optimize(
3955            print_fx,
3956            nopython=args.nopython,
3957        )
3958    elif args.print_aten_ops:
3959        optimize_ctx = torch._dynamo.optimize(
3960            print_aten_ops,
3961            nopython=args.nopython,
3962        )
3963    elif args.nothing:
3964        optimize_ctx = nothing
3965        experiment = speedup_experiment
3966        output_filename = "nothing.csv"
3967    elif args.backend or args.export_aot_inductor:
3968        if args.export_aot_inductor:
3969            assert not args.training, "AOTInductor only supports inference"
3970            optimize_ctx = functools.partial(
3971                export_aot_inductor, device=args.devices[0]
3972            )
3973
3974            # AOTInductor doesn't support control flow yet
3975            runner.skip_models.update(runner.skip_models_due_to_control_flow)
3976        elif args.backend == "torchao":
3977            assert "cuda" in args.devices, "Quantization requires CUDA device."
3978            assert args.bfloat16, "Quantization requires dtype bfloat16."
3979            try:
3980                from torchao_backend import setup_baseline, torchao_optimize_ctx
3981            except ImportError:
3982                from userbenchmark.dynamo.dynamobench.torchao_backend import (
3983                    setup_baseline,
3984                    torchao_optimize_ctx,
3985                )
3986
3987            setup_baseline()
3988            baseline_ctx = functools.partial(
3989                torch.compile,
3990                backend="inductor",
3991                fullgraph=args.nopython,
3992                mode=args.inductor_compile_mode,
3993            )
3994            runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
3995            optimize_ctx = torchao_optimize_ctx(args.quantization)
3996        else:
3997            optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
3998        experiment = speedup_experiment
3999        if args.accuracy:
4000            output_filename = f"accuracy_{args.backend}.csv"
4001        elif args.tolerance:
4002            output_filename = f"tolerance_{args.backend}.csv"
4003        else:
4004            output_filename = f"speedup_{args.backend}.csv"
4005    elif args.recompile_profiler:
4006        output_filename = "recompile_profiler_log.csv"
4007        experiment = recompile_profiler_experiment
4008    else:
4009        optimize_ctx = torch._dynamo.optimize(
4010            fx_insert_profiling, nopython=args.nopython
4011        )
4012        experiment = coverage_experiment
4013        output_filename = "coverage.csv"
4014
4015    if args.inductor or args.backend == "inductor" or args.export_aot_inductor:
4016        inductor_config.triton.cudagraphs = not args.disable_cudagraphs
4017        inductor_config.triton.persistent_reductions = (
4018            not args.disable_persistent_reductions
4019        )
4020        inductor_config.split_reductions = not args.disable_split_reductions
4021        inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16
4022        if args.inference:
4023            inductor_config.freezing = args.freezing
4024
4025    runner.setup_amp()
4026
4027    if args.output:
4028        output_filename = args.output
4029
4030    if output_filename:
4031        if args.output_directory:
4032            output_filename = os.path.join(args.output_directory, output_filename)
4033        else:
4034            output_filename = os.path.join(
4035                torch._dynamo.config.base_dir, output_filename
4036            )
4037
4038    if args.find_batch_sizes and args.only:
4039        for device in args.devices:
4040            batch_size = runner.batch_size_finder(device, args.only)
4041            print(args.only, batch_size)
4042            output_csv(output_filename, [], [args.only, batch_size])
4043        return
4044
4045    if args.export_profiler_trace:
4046        if args.profiler_trace_name is None:
4047            if args.backend:
4048                args.profiler_trace_name = args.backend
4049            elif args.inductor:
4050                args.profiler_trace_name = "inductor"
4051            else:
4052                args.profiler_trace_name = "profile"
4055
4056    if args.no_translation_validation:
4057        # Overwrite 'translation_validation' config, if specified.
4058        torch.fx.experimental._config.translation_validation = False
4059
4060    experiment = functools.partial(experiment, args, runner.model_iter_fn)
4061
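    # --diff-branch mode: re-invoke this script once for the current branch and
    # once for the comparison branch, tagging each run so the resulting CSVs can
    # be compared.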
4062    if args.only and should_diff_branch(args):
4063        import git
4064
4065        repo = git.Repo()
4066        main_branch = repo.active_branch.name
4067        try:
4068            # Adding diff-branch again to the args will override previous value
4069            call_args = (
4070                [sys.executable] + sys.argv + [f"--diff-branch={diff_branch_default}"]
4071            )
4072            # Run for main branch
4073            subprocess.check_call(call_args + [f"--tag={main_branch}"])
4074            # Run for comparison branch
4075            repo.git.checkout(args.diff_branch)
4076            subprocess.check_call(call_args + [f"--tag={args.diff_branch}"])
4077        finally:
4078            # Go back to main branch
4079            repo.git.checkout(main_branch)
4080    elif args.only:
4081        model_name = args.only
4082        for device in args.devices:
4083            batch_size = args.batch_size
4084            if args.batch_size_file:
4085                batch_size = read_batch_size_from_file(
4086                    args, args.batch_size_file, model_name
4087                )
4088            if model_specified_by_path(args.only):
4089                model, example_inputs = load_model_from_path(args.only)
4090                name = model.__class__.__name__
4091                model = model.to(device=device)
4092                example_inputs = tree_map_only(
4093                    torch.Tensor, lambda x: x.to(device=device), example_inputs
4094                )
4095            else:
4096                name = model_name
4097                try:
4098                    with tqdm(desc="loading model"):
4099                        extra_args = []
4100                        if hasattr(args, "rank") and hasattr(args, "world_size"):
4101                            extra_args += [
4102                                "--rank",
4103                                str(args.rank),
4104                                "--world_size",
4105                                str(args.world_size),
4106                            ]
4107
4108                        if args.part:
4109                            (
4110                                device,
4111                                name,
4112                                model,
4113                                example_inputs,
4114                                batch_size,
4115                            ) = runner.load_model(
4116                                device,
4117                                model_name,
4118                                batch_size=batch_size,
4119                                part=args.part,
4120                                extra_args=extra_args,
4121                            )
4122                        else:
4123                            if args.fsdp:
4124                                # Always load model on cpu for fsdp
4125                                # When initializing FSDP, we will use the cuda device if args.cuda is set
4126                                (
4127                                    _,
4128                                    name,
4129                                    model,
4130                                    example_inputs,
4131                                    batch_size,
4132                                ) = runner.load_model(
4133                                    "cpu",
4134                                    model_name,
4135                                    batch_size=batch_size,
4136                                    extra_args=extra_args,
4137                                )
4138                            else:
4139                                (
4140                                    device,
4141                                    name,
4142                                    model,
4143                                    example_inputs,
4144                                    batch_size,
4145                                ) = runner.load_model(
4146                                    device,
4147                                    model_name,
4148                                    batch_size=batch_size,
4149                                    extra_args=extra_args,
4150                                )
4151                except Exception as e:
4152                    import traceback
4153
4154                    mode = "train" if args.training else "eval"
4155                    print(f"{device:4} {mode:5} {name:34} ")
4156                    print(traceback.format_exc())
4157                    status = (
4158                        "model_fail_to_load"
4159                        if isinstance(e, NotImplementedError)
4160                        else "eager_fail_to_run"
4161                    )
4162                    write_csv_when_exception(args, name, status, device)
4163                    continue  # bad benchmark implementation
4164
4165            if args.trace_on_xla:
4166                xla_dev = xm.xla_device()
4167                model = model.to(device=xla_dev)
4168                example_inputs = tree_map_only(
4169                    torch.Tensor, lambda x: x.to(device=xla_dev), example_inputs
4170                )
4171
4172            current_name = name
4173            current_device = device
4174            current_batch_size = batch_size
4175            set_model_name(name)
4176
4177            # Look for stuff that looks like batch size, and mark it dynamic.
4178            # Better integration would integrate directly with benchmark suite
4179            # but cannot conveniently do this
4180            # NB: This must be done late enough so that we don't do more
4181            # conversions on the inputs
4182            # NB: Assumes only the first batch-y like dimension is the batch
4183            marked = False
4184
4185            def detect_and_mark_batch(t):
4186                nonlocal marked
4187                for i, s in enumerate(t.size()):
4188                    if s == batch_size:
4189                        torch._dynamo.mark_dynamic(t, i)
4190                        marked = True
4191                        break
4192
4193            if (
4194                args.dynamic_batch_only
4195                and batch_size > 1
4196                and model_name not in CI_SKIP_DYNAMIC_BATCH_ONLY
4197            ):
4198                tree_map_only(torch.Tensor, detect_and_mark_batch, example_inputs)
4199                assert marked, f"nothing in example_inputs had a dim with {batch_size}"
4200
4201            if args.log_operator_inputs:
4202                log_operator_inputs(
4203                    model, example_inputs, runner.model_iter_fn, name, args
4204                )
4205                continue
4206
4207            if args.per_process_memory_fraction != 1:
4208                torch.cuda.set_per_process_memory_fraction(
4209                    args.per_process_memory_fraction
4210                )
4211            if model_name in DO_NOT_CAST_INPUTS:
4212                model, _ = runner.cast_based_on_args(model, example_inputs)
4213
4214            else:
4215                model, example_inputs = runner.cast_based_on_args(model, example_inputs)
4216            runner.setup_amp(current_device)
4217            guard_ctx = contextlib.nullcontext()
4218            if name in runner.guard_on_nn_module_models:
4219                guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True)
4220
4221            with guard_ctx:
4222                runner.run_one_model(
4223                    name,
4224                    model,
4225                    example_inputs,
4226                    optimize_ctx,
4227                    experiment,
4228                    explain=args.explain,
4229                    tag=args.tag,
4230                )
4231        if args.generate_aot_autograd_stats:
4232            stats_file = output_filename.split(".csv")[0] + "_stats.csv"
4233            output_csv(
4234                stats_file,
4235                ("dev", "name", "batch_size", "total_aot_graphs", "ok_aot_graphs"),
4236                [
4237                    current_device,
4238                    current_name,
4239                    current_batch_size,
4240                    *Stats.aot_summary(),
4241                ],
4242            )
4243    else:
4244        metrics.purge_old_log_files()
4245        if output_filename and os.path.exists(output_filename):
4246            os.unlink(output_filename)
4247        if original_dir:
4248            os.chdir(original_dir)
4249        model_names = list(runner.iter_model_names(args))
4250        nmodels = len(model_names)
4251        for i, name in enumerate(model_names):
4252            current_name = name
4253            if args.progress:
4254                print(f"Running model {i+1}/{nmodels}", flush=True)
4255
4256            try:
4257                timeout = args.timeout
4258                if should_diff_branch(args):
4259                    timeout *= 2
4260                env = os.environ.copy()
4261                if args.ci and name in CI_PRESERVE_COMPILE_DEBUG:
4262                    env["TORCH_COMPILE_DEBUG"] = "1"
4263                subprocess.check_call(
4264                    [sys.executable] + sys.argv + [f"--only={name}"],
4265                    timeout=timeout,
4266                    env=env,
4267                )
4268            except subprocess.TimeoutExpired:
4269                write_csv_when_exception(args, name, "timeout")
4270            except subprocess.CalledProcessError as e:
4271                print("Run failed with return code: ", e.returncode, file=sys.stderr)
4272                print("Output: ", e.output, file=sys.stderr)
4273                print("Error: ", e.stderr, file=sys.stderr)
4274        print_summary(output_filename, print_dataframe=args.print_dataframe_summary)
4275
4276
4277def log_operator_inputs(model, example_inputs, model_iter_fn, name, args):
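    """
    Capture the operator-level inputs seen during one iteration of the model
    (under FakeTensorMode when possible, falling back to real tensors) and dump
    them to <output dir>/<name>_<mode>.txt via OperatorInputsMode.
    """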
4278    mode = "training" if args.training else "eval"
4279    output = os.path.join(os.path.dirname(args.output), f"{name}_{mode}.txt")
4280
4281    # TODO - add option for coalescing inputs over multiple runs
4282    if os.path.exists(output):
4283        print(f"Skipping {name}, {output} already exists")
4284        return
4285
4286    print(f"Running {name}")
4287    try:
4288        from .microbenchmarks.operator_inp_utils import OperatorInputsMode
4289    except ImportError:
4290        from microbenchmarks.operator_inp_utils import OperatorInputsMode
4291
4292    operator_mode = OperatorInputsMode()
4293    fake_tensor_mode = FakeTensorMode()
4294
4295    with torch._subclasses.fake_tensor.FakeCopyMode(fake_tensor_mode):
4296        model_fake = copy.deepcopy(model)
4297        example_inputs_fake = copy.deepcopy(example_inputs)
4298    try:
4299        with fake_tensor_mode, operator_mode:
4300            model_iter_fn(model_fake, example_inputs_fake, collect_outputs=False)
4301    except Exception as e:
4302        print(f"{name} failed to run with fake tensors, trying real. Exception: {e}")
4303        operator_mode = OperatorInputsMode()
4304        try:
4305            with operator_mode:
4306                model_iter_fn(model, example_inputs, collect_outputs=False)
4307        except Exception as e2:
4308            print(f"{name} failed to run with real. Exception: {e2}")
4309            raise
4310
4311    print(f"Writing output to {output}")
4312    operator_mode.log_to_file(output)
4313
4314
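# Example (illustrative) invocation via a suite wrapper:
#   python benchmarks/dynamo/torchbench.py --performance --inference --inductor --only resnet50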
4315if __name__ == "__main__":
4316    raise RuntimeError(
4317        f"You shouldn't run {sys.argv[0]} directly, instead try timm_model.py, torchbench.py or huggingface.py"
4318    )
4319