# mypy: allow-untyped-defs
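"""
Utilities for "after AOT Autograd" repro scripts: wrap_compiler_debug wraps a
backend compiler (such as inductor) so that compile-time failures, runtime
failures, and accuracy problems are dumped as standalone repro scripts or
minifier launchers, and run_repro is the entry point that those generated
scripts call back into (run / minify / analyze / minifier-query / get_args).
"""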
import argparse
import copy
import functools
import io
import logging
import os
import shutil
import subprocess
import sys
import textwrap
import uuid
from importlib import import_module
from tempfile import TemporaryFile
from typing import Any, Callable, Dict, Union

import torch
import torch.fx as fx
import torch.nn as nn
from torch._dynamo.debug_utils import (
    _cuda_system_info_comment,
    AccuracyError,
    backend_accuracy_fails,
    BuckTargetWriter,
    cast_to_fp64,
    extra_imports,
    generate_config_string,
    helper_for_dump_minify,
    InputReader,
    InputWriter,
    MAX_CONSTANT_NUMEL_INLINE,
    minifier_dir,
    NNModuleToString,
    NopInputReader,
    same_two_models,
)
from torch._dynamo.utils import clone_inputs, counters, same
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
    fx_placeholder_targets,
    has_free_symbols,
)
from torch.hub import tqdm

from .. import config


log = logging.getLogger(__name__)


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           MAIN ENTRY POINT
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str):
    """
    Minifier for FX graph modules after AOT Autograd has finished. We wrap the
    forward and backward calls separately with the backend compiler_fn - like
    inductor or nvfuser. Intercepting after AOT Autograd presents a neat
    abstraction, where all the params are lifted as graph inputs, making it
    easy to save the graph as a string.
    """

    @functools.wraps(unconfigured_compiler_fn)
    def debug_wrapper(gm, example_inputs, **kwargs):
        from torch._subclasses import FakeTensorMode

        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)

        from torch._functorch.aot_autograd import get_aot_graph_name

        graph_name = get_aot_graph_name()

        # TODO: why do we need to deepcopy the original graph?
        orig_graph = copy.deepcopy(gm.graph)
        assert config.repro_after in ("dynamo", "aot", None)

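        # What gets dumped is controlled by config.repro_level: 1 dumps the graph
        # repro on error, 2 dumps a minifier launcher on error, 3 always dumps a
        # minifier launcher before running (useful for segfaults), and 4 checks
        # accuracy against eager and dumps on divergence.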
        try:
            # Call the compiler_fn - which is either aot_autograd or inductor
            # with fake inputs
            inner_compiled_fn = compiler_fn(gm, example_inputs)
        except Exception as e:
            # TODO: Failures here are troublesome because no real inputs,
            # need a different serialization strategy
            if config.repro_after == "aot":
                if config.repro_level == 1:
                    dump_compiler_graph_state(
                        fx.GraphModule(gm, orig_graph),
                        example_inputs,
                        compiler_name,
                    )
                elif config.repro_level == 2:
                    dump_to_minify(
                        fx.GraphModule(gm, orig_graph),
                        example_inputs,
                        compiler_name,
                    )
                log.error("CompilerError")
            raise

        # We may run regular PyTorch compute that may trigger Dynamo, do NOT
        # recursively attempt to accuracy minify in that case!
        def deferred_for_real_inputs(real_inputs):
            # This is a bit obscure: if we recursively try to accuracy minify
            # the SAME function, this would trigger.  But most of the time
            # we should never hit this branch
            if config.repro_after != "aot":
                return inner_compiled_fn(real_inputs)
            with config.patch(repro_after=None):
                return inner_debug_fn(real_inputs)

        def inner_debug_fn(real_inputs):
            """
            AOT Autograd's fw_compiler and bw_compiler can receive fake tensors, so
            example_inputs can be fake tensors. We can call compiler_fn (which is
            inductor or nvfuser) with fake tensors, but the actual compiled_fn
            should be called with real tensors. Therefore, the actual invocation
            is deferred.
            """
            # Copy tensor attrs like shape and stride by converting to fake tensors,
            # because inductor clears the tensor list in its codegen and
            # example_inputs are only available for the first invocation.
            fake_mode = FakeTensorMode()
            copy_tensor_attrs = [
                fake_mode.from_tensor(x) if isinstance(x, torch.Tensor) else x
                for x in real_inputs
            ]
            if config.repro_level == 3:
                # Always dump the original module in case we have segfaults
                dump_to_minify(
                    fx.GraphModule(gm, orig_graph), real_inputs, compiler_name
                )

            if config.repro_level == 4:
                if compiler_name != "inductor":
                    raise NotImplementedError(
                        "Accuracy minification is supported for inductor only"
                    )
                failed = not same_two_models(
                    gm,
                    inner_compiled_fn,
                    real_inputs,
                    only_fwd=True,
                    ignore_non_fp=config.repro_ignore_non_fp,
                )

                if failed:
                    log.warning(
                        "Accuracy failed for the AOT Autograd graph %s", graph_name
                    )
                    dump_compiler_graph_state(
                        fx.GraphModule(gm, orig_graph),
                        real_inputs,
                        f"{compiler_name}_accuracy",
                    )
                    dump_to_minify(
                        fx.GraphModule(gm, orig_graph),
                        real_inputs,
                        f"{compiler_name}_accuracy",
                    )
                    raise AccuracyError("Bad accuracy detected")
                else:
                    # Call the compiled function with real inputs
                    return inner_compiled_fn(real_inputs)
            else:
                try:
                    # Call the compiled function with real inputs
                    out = inner_compiled_fn(real_inputs)
                    # sync cuda kernels to ensure IMA (illegal memory access) detection
                    for arg in example_inputs:
                        if isinstance(arg, torch.Tensor) and arg.is_cuda:
                            torch.cuda.synchronize()
                            break
                    return out
                except Exception as e:
                    if config.repro_level == 1:
                        dump_compiler_graph_state(
                            fx.GraphModule(gm, orig_graph),
                            copy_tensor_attrs,
                            compiler_name,
                        )
                    elif config.repro_level == 2:
                        dump_to_minify(
                            fx.GraphModule(gm, orig_graph),
                            copy_tensor_attrs,
                            compiler_name,
                        )
                    raise

        if config.repro_after == "aot":
            compiled_fn = deferred_for_real_inputs
            compiled_fn._boxed_call = True  # type: ignore[attr-defined]
            return compiled_fn
        else:
            return inner_compiled_fn

    return debug_wrapper

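# Minimal usage sketch (illustrative only; the real wiring lives in the inductor
# and AOT Autograd plumbing, so the exact call site may differ):
#
#   debug_compiler = wrap_compiler_debug(compile_fx_inner, compiler_name="inductor")
#   compiled_fn = debug_compiler(gm, example_inputs)  # dumps a repro on failure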

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           DUMP REPROS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def generate_compiler_repro_string(gm, args, *, stable_output=False, save_dir=None):
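    """
    Build the source of a standalone repro script for gm: imports and config
    strings, the module converted to source via NNModuleToString (a `Repro`
    nn.Module), the serialized inputs (load_args), and finally `mod = Repro()`.
    """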
    model_str = textwrap.dedent(
        f"""
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims

{generate_config_string(stable_output=stable_output)}

isolate_fails_code_str = None

{extra_imports}

        """
    )
    if not stable_output:
        model_str += f"# torch version: {torch.version.__version__}\n"
        if hasattr(torch.version, "cuda"):
            model_str += f"# torch cuda version: {torch.version.cuda}\n"
        if hasattr(torch.version, "git_version"):
            model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
        model_str += _cuda_system_info_comment()

    model_str += NNModuleToString.convert(gm)

    # get the hint shape/stride when dynamic shapes are enabled
    def hint_if_symint(x):
        return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x)

    writer = InputWriter(save_dir)
    for placeholder, arg in zip(fx_placeholder_targets(gm), args):
        if isinstance(arg, (int, torch.SymInt)):
            writer.symint(placeholder, arg)
        elif isinstance(arg, torch.Tensor):
            # TODO: improve these names with FQN
            writer.tensor(placeholder, arg)
        else:
            raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}")

    model_str += "\n".join(writer.lines()) + "\n"

    model_str += "mod = Repro()\n"
    return model_str


def save_graph_repro(
    fd,
    gm,
    args,
    compiler_name,
    *,
    stable_output=False,
    save_dir=None,
    command="run",
    accuracy=None,
    tracing_mode=None,
    check_str=None,
):
    if any(
        isinstance(arg, torch.fx.experimental._backward_state.BackwardState)
        for arg in args
    ):
        fd.write(
            "Repro is not generated due to the existence of BackwardState in a graph input"
        )
        return
    fd.write(
        generate_compiler_repro_string(
            gm,
            args,
            stable_output=stable_output,
            save_dir=save_dir,
        )
    )
    if accuracy is None:
        accuracy = "_accuracy" in compiler_name
    if tracing_mode is None:
        tracing_mode = "real"
        if any(has_free_symbols(a) for a in args):
            tracing_mode = "symbolic"
    fd.write("if __name__ == '__main__':\n")
    fd.write("    from torch._dynamo.repro.after_aot import run_repro\n")
    fd.write(
        f"    with torch.no_grad():\n"
        f"        run_repro(mod, load_args, accuracy={accuracy!r}, command={command!r}, "
        f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n"
        f"        # To run it separately, do \n"
        f"        # mod, args = run_repro(mod, load_args, accuracy={accuracy!r}, command='get_args', "
        f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n"
        f"        # mod(*args)"
    )

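# Shape of a generated repro script (sketch; the exact contents depend on the
# graph, inputs, and config):
#
#   import torch
#   ...
#   class Repro(torch.nn.Module):
#       ...
#   def load_args(reader):
#       ...
#   mod = Repro()
#   if __name__ == '__main__':
#       from torch._dynamo.repro.after_aot import run_repro
#       with torch.no_grad():
#           run_repro(mod, load_args, accuracy=False, command='run', ...)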

def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None):
    subdir = os.path.join(minifier_dir(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
    log.warning(
        "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name
    )
    with open(file_name, "w") as fd:
        save_graph_repro(
            fd, gm, args, compiler_name, save_dir=subdir, accuracy=accuracy
        )
    curdir = os.getcwd()
    repro_path = os.path.join(curdir, "repro.py")
    try:
        shutil.copyfile(file_name, repro_path)
        log.warning("Copying repro file for convenience to %s", repro_path)
        if use_buck:
            BuckTargetWriter(file_name).write()
    except OSError:
        log.warning("No write permissions for %s", repro_path)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           DUMP MINIFIER
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def dump_to_minify(gm, args, compiler_name: str):
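    """Write a repro whose default command is 'minify', via helper_for_dump_minify,
    so that executing the dumped script launches the minifier on gm."""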
    out = io.StringIO()
    # TODO: factor this out
    subdir = os.path.join(minifier_dir(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    save_graph_repro(out, gm, args, compiler_name, save_dir=subdir, command="minify")
    return helper_for_dump_minify(out.getvalue())


def isolate_fails(
    fx_g,
    args,
    compiler_name: str,
    env=None,
    save_dir=None,
    accuracy=None,
    tracing_mode=None,
    check_str=None,
):
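    """
    Write a 'minifier-query' repro for fx_g to a throwaway file and run it in a
    fresh subprocess; return True if the subprocess exits nonzero, i.e. the
    candidate graph still fails when run in isolation.
    """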
    if env is None:
        env = {}
    subdir = os.path.join(os.getcwd(), "isolate")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
    with open(file_name, "w") as fd:
        save_graph_repro(
            fd,
            fx_g,
            args,
            compiler_name,
            save_dir=save_dir,
            command="minifier-query",
            accuracy=accuracy,
            tracing_mode=tracing_mode,
            check_str=check_str,
        )
    # with open(file_name, "r") as fd:
    #     print(fd.read())
    new_env = os.environ.copy()
    new_env = {**new_env, **env}
    stdout, stderr = TemporaryFile(), TemporaryFile()

    if use_buck:
        cmd = BuckTargetWriter(file_name).write(print_msg=False)
    else:
        cmd = ["python", file_name]

    p = subprocess.Popen(
        cmd,
        cwd=subdir,
        stdout=stdout,
        stderr=stderr,
        env=new_env,
    )
    p.wait()

    stdout.seek(0)
    stderr.seek(0)
    print(
        textwrap.indent(stdout.read().decode("utf-8"), prefix=">>  "), file=sys.stdout
    )
    print(
        textwrap.indent(stderr.read().decode("utf-8"), prefix=">>  "), file=sys.stderr
    )
    # print(f"Isolated test failed - {file_name}")
    return p.returncode != 0


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                       MINIFIER TOOLS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def inductor_fails(fx_g, args, check_str=None):
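    """
    Return True if fx_g runs under eager but compiling and running it with
    inductor's compile_fx_inner raises (optionally requiring check_str to appear
    in the exception); return False if eager itself fails or no error is seen.
    """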
    has_cuda = False
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.is_cuda:
            has_cuda = True
            break

    def sync():
        if has_cuda:
            # Ensures that segfaults are surfaced
            torch.cuda.synchronize()

    from torch._inductor.compile_fx import compile_fx_inner

    try:
        result = fx_g(*args)
        assert isinstance(result, (tuple, list))
        assert not any(isinstance(x, (tuple, list)) for x in result)
    except Exception:
        return False

    sync()

    try:
        compile_mod = compile_fx_inner(fx_g, args)
        compile_mod(args)
        sync()
    except Exception as e:
        if check_str is not None and check_str not in repr(e):
            return False
        print(repr(e))
        return True
    return False


def inductor_accuracy_fails(
    fx_g, args, check_str=None, *, require_fp64=False, ignore_non_fp=False
):
    from torch._inductor.compile_fx import compile_fx_inner

    return backend_aot_accuracy_fails(
        fx_g,
        args,
        compile_fx_inner,
        require_fp64=require_fp64,
        ignore_non_fp=ignore_non_fp,
    )


backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           REPRO MAIN
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def repro_common(options, mod, load_args):
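    """
    Shared setup for every repro command: load the serialized inputs through
    load_args/InputReader, retrace mod into a GraphModule with make_fx, enable
    inductor's intermediate hooks, and return (mod, args).
    """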
    # Invariant for graphs we generate with the repro script
    assert not any(mod.named_parameters())
    for n, b in mod.named_buffers():
        if b.numel() > MAX_CONSTANT_NUMEL_INLINE:
            log.warning(
                "Constant %s was not serialized, generated random data instead. "
                "If you think this is affecting you, please comment on "
                "https://github.com/pytorch/pytorch/issues/100468",
                n,
            )

    if not hasattr(load_args, "_version"):
        log.warning(
            "load_args does not have a _version attribute, please file a bug to PyTorch "
            "and describe how you generated this repro script"
        )
    else:
        if load_args._version > 0:
            log.warning(
                "load_args is version %s, but this version of PyTorch only supports "
                "version 0.  We will try to run it anyway but there may be an incompatibility; "
                "if so, try upgrading your version of PyTorch.",
                load_args._version,
            )

    nop_reader = NopInputReader()
    load_args(nop_reader)

    with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar:
        input_reader = InputReader(save_dir=options.save_dir, pbar=pbar)
        load_args(input_reader)
        args = input_reader.args

    # Turn mod into a GraphModule the slow way
    # TODO: speed this up
    mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args)

    torch._inductor.config.generate_intermediate_hooks = True

    return mod, args


ACCURACY_FAILS: Dict[str, Callable[[nn.Module, Any], bool]] = {
    "": inductor_fails,
    # This might look inverted but it's not.  strict_accuracy means "we will
    # minify any time we see anything that diverges", whereas accuracy is more
    # conservative, and will only minify if there is a meaningful fp64
    # divergence
    "accuracy": functools.partial(
        inductor_accuracy_fails, require_fp64=True, ignore_non_fp=True
    ),
    "strict_accuracy": inductor_accuracy_fails,
}


def repro_minifier_query(options, mod, load_args):
    mod, args = repro_common(options, mod, load_args)
    fail_fn = functools.partial(
        ACCURACY_FAILS[options.accuracy], check_str=options.check_str
    )
    if fail_fn(mod, args):
        sys.exit(1)
    else:
        sys.exit(0)


def repro_minify(options, mod, load_args):
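    """
    Run the functorch minifier over the repro graph, checking candidates either
    with isolate_fails in a subprocess (the default) or with the in-process
    ACCURACY_FAILS checker, and checkpointing failing reductions via
    dump_compiler_graph_state.
    """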
    from functorch.compile import minifier

    mod, args = repro_common(options, mod, load_args)
    compiler_name = "inductor_accuracy" if options.accuracy != "" else "inductor"

    favored_device = 1 if torch.cuda.device_count() >= 2 else 0
    env_variables = {"CUDA_VISIBLE_DEVICES": str(favored_device)}

    module_fails: Any
    if options.isolate:
        module_fails = functools.partial(
            isolate_fails,
            env=env_variables,
            compiler_name=compiler_name,
            save_dir=options.save_dir,
            accuracy=options.accuracy,
            tracing_mode=options.tracing_mode,
        )
    else:
        module_fails = ACCURACY_FAILS[options.accuracy]

    minifier(
        mod,
        args,
        module_fails=functools.partial(module_fails, check_str=options.check_str),
        dump_state=functools.partial(
            dump_compiler_graph_state, compiler_name=compiler_name
        ),
        save_dir=options.save_dir,
        offload_to_disk=options.offload_to_disk,
        skip_offload=options.skip_saving_eager_intermediates,
        skip_sanity=options.skip_sanity,
        max_granularity=options.max_granularity,
    )


def repro_analyze(options, mod, load_args):
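    """
    Compile the repro with inductor and locate where it diverges from eager:
    save inductor intermediates via intermediate hooks and float64 eager
    intermediates via an FX interpreter, optionally check both for determinism,
    then interpret the eager graph and report intermediates that are not 'same'
    (within config.repro_tolerance) as the inductor/float64 references.
    """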
    from torch._inductor.compile_fx import compile_fx_inner
    from torch._inductor.hooks import intermediate_hook

    mod, args = repro_common(options, mod, load_args)

    # TODO: The logic for cloning inputs/models here is intentionally
    # modeled off of run_fwd_maybe_bwd, but arguably it is better not to
    # clone inputs (as you are doubling your effective GPU memory usage).
    # It is certainly faster though!  It probably makes sense to let the
    # user specify the offload strategy.

    with tqdm(desc="Compiling"):
        compiled = compile_fx_inner(mod, args)
    total = counters["inductor"]["intermediate_hooks"]

    known_names = set()

    def save_hook(name, val):
        known_names.add(name)
        if not options.skip_saving_inductor_intermediates:
            writer.write_tensor(os.path.join("inductor", name), val)
        pbar.update(1)  # type: ignore[has-type]

    writer = torch.utils._content_store.ContentStoreWriter(
        options.save_dir, stable_hash=options.stable_hash
    )
    reader = torch.utils._content_store.ContentStoreReader(options.save_dir)

    new_args = clone_inputs(args)
    with intermediate_hook(save_hook), tqdm(
        desc="Saving inductor intermediates", total=total
    ) as pbar:
        compiled(new_args)
        assert not new_args

    def compare_tuples(tuple1, tuple2):
        diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]]
        diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices]

        if not diff_values:
            return None
        else:
            return " and ".join(f"{a} != {b}" for a, b in diff_values)

    def check_hook(name, val):
        meta = writer.compute_tensor_metadata(val)
        meta2 = reader.read_tensor_metadata(os.path.join("inductor", name))
        reason = compare_tuples(meta, meta2)
        if reason is not None:
            pbar.write(f"NONDETERMINISTIC INDUCTOR at {name} ({reason})")
        pbar.update(1)

    if not options.skip_check_deterministic:
        new_args = clone_inputs(args)
        with intermediate_hook(check_hook), tqdm(
            desc="Checking inductor determinism", total=total
        ) as pbar:
            compiled(new_args)
            assert not new_args

    class WriterInterp(fx.Interpreter):
        def __init__(self, mod, subdir) -> None:
            super().__init__(mod)
            self.subdir = subdir

        def run_node(self, n):
            r = super().run_node(n)
            name = n.name
            if name in known_names:
                pbar.update(1)
                writer.write_tensor(os.path.join(self.subdir, name), r)
            return r

    # NB: the module cast doesn't actually do anything, since there are no
    # parameters/buffers on the module
    if not options.skip_saving_float64_intermediates:
        new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args))
        with tqdm(desc="Saving float64 intermediates", total=total) as pbar:
            WriterInterp(new_mod, "float64").boxed_run(new_args)
        assert not new_args

    class ExactReaderInterp(fx.Interpreter):
        def run_node(self, n):
            r = super().run_node(n)
            name = n.name
            if name in known_names:
                meta = writer.compute_tensor_metadata(r)
                meta2 = reader.read_tensor_metadata(os.path.join("float64", name))
                reason = compare_tuples(meta, meta2)
                if reason is not None:
                    pbar.write(f"NONDETERMINISTIC FLOAT64 at {name} ({reason})")
                pbar.update(1)
            return r

    # TODO: check eager determinism

    if not options.skip_check_deterministic:
        new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args))
        with tqdm(desc="Checking float64 determinism", total=total) as pbar:
            ExactReaderInterp(new_mod).boxed_run(new_args)
            assert not new_args

    # Now that we've saved everything, interp through the eager graph
    # and do comparisons
    class ReaderInterp(fx.Interpreter):
        def run_node(self, n):
            r = super().run_node(n)
            name = n.name
            if name in known_names:
                inductor = reader.read_tensor(os.path.join("inductor", name))
                float64 = reader.read_tensor(os.path.join("float64", name))
                logged = False

                def log_error(msg, *args):
                    nonlocal logged
                    logged = True
                    pbar.write(f"DIVERGED at {name}: {msg % args}")

                if not same(
                    r,
                    inductor,
                    float64,
                    tol=torch._dynamo.config.repro_tolerance,
                    equal_nan=True,
                    log_error=log_error,
                ):
                    assert logged
                pbar.update(1)
            return r

    with tqdm(desc="Checking divergence", total=total) as pbar:
        ReaderInterp(mod).boxed_run(args)
    assert not args


def repro_get_args(options, mod, load_args):
    mod, args = repro_common(options, mod, load_args)
    return mod, args


def repro_run(options, mod, load_args):
    from torch._inductor.compile_fx import compile_fx_inner

    mod, args = repro_common(options, mod, load_args)

    from torch.cuda import synchronize

    compiled = compile_fx_inner(mod, args)

    if options.accuracy != "":
        # We don't really respect --accuracy vs --strict-accuracy here, it
        # seems counterintuitive
        if not same_two_models(
            mod,
            compiled,
            args,
            only_fwd=True,
            ignore_non_fp=config.repro_ignore_non_fp,
        ):
            raise AccuracyError("Bad accuracy detected")
    else:
        need_sync = False
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.is_cuda:
                need_sync = True
                break
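        # compile_fx_inner returns a boxed callable that may consume (clear) the
        # argument list it is handed, so pass a fresh list copy of args.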
        ref = compiled(list(args))
        if need_sync:
            synchronize()  # ensure segfaults are surfaced
    return lambda: compiled(list(args))


# TODO: lazily load the inputs or something, rather than cloning them
def run_repro(
    mod,
    load_args,
    *,
    command="run",
    accuracy: Union[bool, str] = "",
    save_dir=None,
    tracing_mode=None,
    patch_code=None,
    check_str=None,
    **kwargs,
):
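    """
    Entry point called by generated repro scripts. The keyword arguments are the
    defaults baked into the script; command-line flags (run / minify / analyze /
    minifier-query / get_args plus the common flags) can override them, and the
    chosen sub-command is dispatched to the matching repro_* function.
    """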
    for k in kwargs:
        log.warning(
            "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch",
            k,
        )

    if accuracy is True:
        accuracy = "accuracy"
    elif accuracy is False:
        accuracy = ""

    if patch_code is not None:
        log.warning(
            "patch_code no longer works on this version of PyTorch, silently ignoring"
        )

    parser = argparse.ArgumentParser(
        description=f"""\
An after_aot repro script, typically triggering a bug in PyTorch Inductor.
When run with no arguments, this script defaults to running '{command}'.
Extra flags may be available; to find out more, try '{command} --help'.
There are also alternate subcommands available, see below.

default settings on this script:
  {accuracy=}
  {tracing_mode=}
  {save_dir=}
  {check_str=}
""",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    def common_flags(parser):
        accuracy_group = parser.add_mutually_exclusive_group()
        accuracy_group.add_argument(
            "--no-accuracy",
            dest="accuracy",
            action="store_const",
            const="",
            default=accuracy,
            help="do not test accuracy, just run the module and see if it errors",
        )
        accuracy_group.add_argument(
            "--accuracy",
            action="store_const",
            const="accuracy",
            default=accuracy,
            help="""\
test if the RMSE between the compiled module and the fp64 reference is greater
than the RMSE between eager and the fp64 reference. This is usually more
reliable than the standard allclose test, as we expect numeric differences from
compiling, often improving accuracy over eager.  The RMSE test allows the
compiled module to diverge greatly from eager, as long as this divergence moves
it closer to the 'true' mathematical value of the network.  Caveats: (1) double
precision can still suffer from rounding error, so it is not a perfect
reference (see for example 'Herbie: Automatically Improving Floating Point
Accuracy' for approaches that detect the necessary working precision and
compute it in arbitrary precision floating point; unfortunately, this is not
practical for tensor computation); (2) if there are not enough samples in the
output being compared, we may get unlucky and see a greater RMSE than eager by
chance; this could be overcome by applying a more rigorous statistical test at
some p-value, which we leave for future work.
""",
        )
        accuracy_group.add_argument(
            "--strict-accuracy",
            dest="accuracy",
            action="store_const",
            const="strict_accuracy",
            default=accuracy,
            help="""\
by default, when doing accuracy minification we will reject reductions which
change the divergence from a floating point divergence to an integral/boolean
divergence.  This is because some operations like ReLU involve temporarily
sharp boundaries that smooth out again afterwards; without requiring
divergence on floating point, the minifier will often fixate on a divergent
boolean tensor even though this is not the true source of the divergence.
However, rejecting these reductions makes it more difficult for the minifier
to make progress.  Using this option will let the minifier make progress on ALL
divergences--you just might not end up with a useful repro in the end.""",
        )

        parser.add_argument(
            "--save-dir",
            type=str,
            default=save_dir,
            metavar="DIR",
            help="directory where saved inputs live",
        )
        parser.add_argument(
            "--no-save-dir",
            dest="save_dir",
            action="store_const",
            const=None,
            help="don't use any directory for saved inputs",
        )
        parser.add_argument(
            "--tracing-mode",
            type=str,
            metavar="{real,fake,symbolic}",
            default=tracing_mode,
            help="how to trace the repro module into a GraphModule with metadata",
        )

    subparsers = parser.add_subparsers(
        dest="command", metavar="{run,minify,analyze}", required=True
    )

    parser_run = subparsers.add_parser(
        "run",
        help="just run the repro",
    )
    common_flags(parser_run)

    parser_minify = subparsers.add_parser(
        "minify", help="run the minifier on the repro"
    )
    common_flags(parser_minify)
    parser_get_args = subparsers.add_parser("get_args", help="get the args")
    common_flags(parser_get_args)
    parser_minify_isolate = parser_minify.add_mutually_exclusive_group()
    parser_minify_isolate.add_argument(
        "--isolate",
        action="store_true",
        default=True,
        help="run in separate processes to avoid interference (default)",
    )
    parser_minify_isolate.add_argument(
        "--no-isolate",
        dest="isolate",
        action="store_false",
        help="speed up by running all compilation in the same process",
    )
    parser_minify.add_argument(
        "--skip-saving-eager-intermediates",
        action="store_true",
        help="skip saving eager intermediates on --minify",
    )
    # TODO: make this an option for --analyze too
    parser_minify.add_argument(
        "--offload-to-disk",
        action="store_true",
        help="during minification, offload delta debugging intermediates to disk.  Use if you're OOMing",
    )
    parser_minify.add_argument(
        "--skip-sanity",
        action="store_true",
        help="skip sanity check at beginning of minification on original graph",
    )
    parser_minify.add_argument(
        "--max-granularity",
        type=int,
        default=None,
        help="start at this granularity and work down; must be a power of 2",
    )
    parser_minify.add_argument(
        "--check-str",
        type=str,
        default=check_str,
        help="require minified program to fail with error containing this string",
    )

    parser_analyze = subparsers.add_parser(
        "analyze", help="run the accuracy analyzer on the repro"
    )
    common_flags(parser_analyze)
    parser_analyze.add_argument(
        "--skip-saving-inductor-intermediates",
        action="store_true",
        help="skip saving inductor intermediates on --analyze",
    )
    parser_analyze.add_argument(
        "--skip-saving-float64-intermediates",
        action="store_true",
        help="skip saving float64 intermediates",
    )
    parser_analyze.add_argument(
        "--skip-check-deterministic",
        action="store_true",
        help="skip checking that the network is deterministic",
    )
    parser_analyze.add_argument(
        "--stable-hash",
        action="store_true",
        help="use SHA-1 checksum instead of fast (but possibly unsound) hash",
    )

    # Run the repro in the context of minification, inverting exit code meaning
    parser_minifier_query = subparsers.add_parser(
        "minifier-query",
    )
    common_flags(parser_minifier_query)
    parser_minifier_query.add_argument(
        "--check-str",
        type=str,
        default=check_str,
        help="require minified program to fail with error containing this string",
    )

    args = None
    if len(sys.argv) <= 1:
        args = [command, *sys.argv[1:]]

    options = parser.parse_args(args)
    COMMAND_FNS = {
        "minify": repro_minify,
        "analyze": repro_analyze,
        "minifier-query": repro_minifier_query,
        "run": repro_run,
        "get_args": repro_get_args,
    }
    return COMMAND_FNS[options.command](options, mod, load_args)
967