1# mypy: allow-untyped-defs 2import argparse 3import copy 4import functools 5import io 6import logging 7import os 8import shutil 9import subprocess 10import sys 11import textwrap 12import uuid 13from importlib import import_module 14from tempfile import TemporaryFile 15from typing import Any, Callable, Dict, Union 16 17import torch 18import torch.fx as fx 19import torch.nn as nn 20from torch._dynamo.debug_utils import ( 21 _cuda_system_info_comment, 22 AccuracyError, 23 backend_accuracy_fails, 24 BuckTargetWriter, 25 cast_to_fp64, 26 extra_imports, 27 generate_config_string, 28 helper_for_dump_minify, 29 InputReader, 30 InputWriter, 31 MAX_CONSTANT_NUMEL_INLINE, 32 minifier_dir, 33 NNModuleToString, 34 NopInputReader, 35 same_two_models, 36) 37from torch._dynamo.utils import clone_inputs, counters, same 38from torch.fx.experimental.proxy_tensor import make_fx 39from torch.fx.experimental.symbolic_shapes import ( 40 fx_placeholder_targets, 41 has_free_symbols, 42) 43from torch.hub import tqdm 44 45from .. import config 46 47 48log = logging.getLogger(__name__) 49 50 51inductor_config = import_module("torch._inductor.config") 52use_buck = inductor_config.is_fbcode() 53 54# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 55# MAIN ENTRY POINT 56# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 57 58 59def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str): 60 """ 61 Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both 62 forward and backward call separately with the backend compiler_fn - like 63 inductor or nvfuser. Intercepting after Aot Autograd presents neat 64 abstraction, where all the params are lifted as graph inputs, making it easy 65 to save the graph as a string. 66 """ 67 68 @functools.wraps(unconfigured_compiler_fn) 69 def debug_wrapper(gm, example_inputs, **kwargs): 70 from torch._subclasses import FakeTensorMode 71 72 compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) 73 74 from torch._functorch.aot_autograd import get_aot_graph_name 75 76 graph_name = get_aot_graph_name() 77 78 # TODO: why do we need to deepcopy the original graph? 79 orig_graph = copy.deepcopy(gm.graph) 80 assert config.repro_after in ("dynamo", "aot", None) 81 82 try: 83 # Call the compiler_fn - which is either aot_autograd or inductor 84 # with fake inputs 85 inner_compiled_fn = compiler_fn(gm, example_inputs) 86 except Exception as e: 87 # TODO: Failures here are troublesome because no real inputs, 88 # need a different serialization strategy 89 if config.repro_after == "aot": 90 if config.repro_level == 1: 91 dump_compiler_graph_state( 92 fx.GraphModule(gm, orig_graph), 93 example_inputs, 94 compiler_name, 95 ) 96 elif config.repro_level == 2: 97 dump_to_minify( 98 fx.GraphModule(gm, orig_graph), 99 example_inputs, 100 compiler_name, 101 ) 102 log.error("CompilerError") 103 raise 104 105 # We may run regular PyTorch compute that may trigger Dynamo, do NOT 106 # recursively attempt to accuracy minify in that case! 107 def deferred_for_real_inputs(real_inputs): 108 # This is a bit obscure: if we recursively try to accuracy minify 109 # the SAME function, this would trigger. But most of the time 110 # we should never hit this branch 111 if config.repro_after != "aot": 112 return inner_compiled_fn(real_inputs) 113 with config.patch(repro_after=None): 114 return inner_debug_fn(real_inputs) 115 116 def inner_debug_fn(real_inputs): 117 """ 118 Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, 119 example_inputs can be fake tensors. We can call compiler_fn (which is 120 inductor or nvfuser) with fake tensors but the actually compiled_fn 121 should be called with real tensors. Therefore, the actual invocation 122 is deferred. 123 """ 124 # Copy the tensor attrs like shape, stride etc by converting to Fake Tensor 125 # because inductor clears the tensor list in its codegen. And example_inputs 126 # are available only for the first invocation. 127 fake_mode = FakeTensorMode() 128 copy_tensor_attrs = [ 129 fake_mode.from_tensor(x) if isinstance(x, torch.Tensor) else x 130 for x in real_inputs 131 ] 132 if config.repro_level == 3: 133 # Always dump the original module in case we have segfaults 134 dump_to_minify( 135 fx.GraphModule(gm, orig_graph), real_inputs, compiler_name 136 ) 137 138 if config.repro_level == 4: 139 if compiler_name != "inductor": 140 raise NotImplementedError( 141 "Accuracy minification is supported for inductor only" 142 ) 143 failed = not same_two_models( 144 gm, 145 inner_compiled_fn, 146 real_inputs, 147 only_fwd=True, 148 ignore_non_fp=config.repro_ignore_non_fp, 149 ) 150 151 if failed: 152 log.warning( 153 "Accuracy failed for the AOT Autograd graph %s", graph_name 154 ) 155 dump_compiler_graph_state( 156 fx.GraphModule(gm, orig_graph), 157 real_inputs, 158 f"{compiler_name}_accuracy", 159 ) 160 dump_to_minify( 161 fx.GraphModule(gm, orig_graph), 162 real_inputs, 163 f"{compiler_name}_accuracy", 164 ) 165 raise AccuracyError("Bad accuracy detected") 166 else: 167 # Call the compiled function with real inputs 168 return inner_compiled_fn(real_inputs) 169 else: 170 try: 171 # Call the compiled function with real inputs 172 out = inner_compiled_fn(real_inputs) 173 # sync cuda kernels to ensure IMA detection 174 for arg in example_inputs: 175 if isinstance(arg, torch.Tensor) and arg.is_cuda: 176 torch.cuda.synchronize() 177 break 178 return out 179 except Exception as e: 180 if config.repro_level == 1: 181 dump_compiler_graph_state( 182 fx.GraphModule(gm, orig_graph), 183 copy_tensor_attrs, 184 compiler_name, 185 ) 186 elif config.repro_level == 2: 187 dump_to_minify( 188 fx.GraphModule(gm, orig_graph), 189 copy_tensor_attrs, 190 compiler_name, 191 ) 192 raise 193 194 if config.repro_after == "aot": 195 compiled_fn = deferred_for_real_inputs 196 compiled_fn._boxed_call = True # type: ignore[attr-defined] 197 return compiled_fn 198 else: 199 return inner_compiled_fn 200 201 return debug_wrapper 202 203 204# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 205# DUMP REPROS 206# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 207 208 209def generate_compiler_repro_string(gm, args, *, stable_output=False, save_dir=None): 210 model_str = textwrap.dedent( 211 f""" 212import torch 213from torch import tensor, device 214import torch.fx as fx 215from torch._dynamo.testing import rand_strided 216from math import inf 217import torch._inductor.inductor_prims 218 219{generate_config_string(stable_output=stable_output)} 220 221isolate_fails_code_str = None 222 223{extra_imports} 224 225 """ 226 ) 227 if not stable_output: 228 model_str += f"# torch version: {torch.version.__version__}\n" 229 if hasattr(torch.version, "cuda"): 230 model_str += f"# torch cuda version: {torch.version.cuda}\n" 231 if hasattr(torch.version, "git_version"): 232 model_str += f"# torch git version: {torch.version.git_version}\n\n\n" 233 model_str += _cuda_system_info_comment() 234 235 model_str += NNModuleToString.convert(gm) 236 237 # get hint shape/stride when dynamic shape enabled 238 def hint_if_symint(x): 239 return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x) 240 241 writer = InputWriter(save_dir) 242 for placeholder, arg in zip(fx_placeholder_targets(gm), args): 243 if isinstance(arg, (int, torch.SymInt)): 244 writer.symint(placeholder, arg) 245 elif isinstance(arg, torch.Tensor): 246 # TODO: improve these names with FQN 247 writer.tensor(placeholder, arg) 248 else: 249 raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}") 250 251 model_str += "\n".join(writer.lines()) + "\n" 252 253 model_str += "mod = Repro()\n" 254 return model_str 255 256 257def save_graph_repro( 258 fd, 259 gm, 260 args, 261 compiler_name, 262 *, 263 stable_output=False, 264 save_dir=None, 265 command="run", 266 accuracy=None, 267 tracing_mode=None, 268 check_str=None, 269): 270 if any( 271 isinstance(arg, torch.fx.experimental._backward_state.BackwardState) 272 for arg in args 273 ): 274 fd.write( 275 "Repro is not generated due to existence of BackwardState in graph input" 276 ) 277 return 278 fd.write( 279 generate_compiler_repro_string( 280 gm, 281 args, 282 stable_output=stable_output, 283 save_dir=save_dir, 284 ) 285 ) 286 if accuracy is None: 287 accuracy = "_accuracy" in compiler_name 288 if tracing_mode is None: 289 tracing_mode = "real" 290 if any(has_free_symbols(a) for a in args): 291 tracing_mode = "symbolic" 292 fd.write("if __name__ == '__main__':\n") 293 fd.write(" from torch._dynamo.repro.after_aot import run_repro\n") 294 fd.write( 295 f" with torch.no_grad():\n" 296 f" run_repro(mod, load_args, accuracy={accuracy!r}, command={command!r}, " 297 f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" 298 f" # To run it separately, do \n" 299 f" # mod, args = run_repro(mod, load_args, accuracy={accuracy!r}, command='get_args', " 300 f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" 301 f" # mod(*args)" 302 ) 303 304 305def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None): 306 subdir = os.path.join(minifier_dir(), "checkpoints") 307 if not os.path.exists(subdir): 308 os.makedirs(subdir, exist_ok=True) 309 file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") 310 log.warning( 311 "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name 312 ) 313 with open(file_name, "w") as fd: 314 save_graph_repro( 315 fd, gm, args, compiler_name, save_dir=subdir, accuracy=accuracy 316 ) 317 curdir = os.getcwd() 318 repro_path = os.path.join(curdir, "repro.py") 319 try: 320 shutil.copyfile(file_name, repro_path) 321 log.warning("Copying repro file for convenience to %s", repro_path) 322 if use_buck: 323 BuckTargetWriter(file_name).write() 324 except OSError: 325 log.warning("No write permissions for %s", repro_path) 326 327 328# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 329# DUMP MINIFIER 330# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 331 332 333def dump_to_minify(gm, args, compiler_name: str): 334 out = io.StringIO() 335 # TODO: factor this out 336 subdir = os.path.join(minifier_dir(), "checkpoints") 337 if not os.path.exists(subdir): 338 os.makedirs(subdir, exist_ok=True) 339 save_graph_repro(out, gm, args, compiler_name, save_dir=subdir, command="minify") 340 return helper_for_dump_minify(out.getvalue()) 341 342 343def isolate_fails( 344 fx_g, 345 args, 346 compiler_name: str, 347 env=None, 348 save_dir=None, 349 accuracy=None, 350 tracing_mode=None, 351 check_str=None, 352): 353 if env is None: 354 env = {} 355 subdir = os.path.join(os.getcwd(), "isolate") 356 if not os.path.exists(subdir): 357 os.makedirs(subdir, exist_ok=True) 358 file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") 359 with open(file_name, "w") as fd: 360 save_graph_repro( 361 fd, 362 fx_g, 363 args, 364 compiler_name, 365 save_dir=save_dir, 366 command="minifier-query", 367 accuracy=accuracy, 368 tracing_mode=tracing_mode, 369 check_str=check_str, 370 ) 371 # with open(file_name, "r") as fd: 372 # print(fd.read()) 373 new_env = os.environ.copy() 374 new_env = {**new_env, **env} 375 stdout, stderr = TemporaryFile(), TemporaryFile() 376 377 if use_buck: 378 cmd = BuckTargetWriter(file_name).write(print_msg=False) 379 else: 380 cmd = ["python", file_name] 381 382 p = subprocess.Popen( 383 cmd, 384 cwd=subdir, 385 stdout=stdout, 386 stderr=stderr, 387 env=new_env, 388 ) 389 p.wait() 390 391 stdout.seek(0) 392 stderr.seek(0) 393 print( 394 textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), file=sys.stdout 395 ) 396 print( 397 textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), file=sys.stderr 398 ) 399 # print(f"Isolated test failed - {file_name}") 400 return p.returncode != 0 401 402 403# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 404# MINIFIER TOOLS 405# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 406 407 408def inductor_fails(fx_g, args, check_str=None): 409 has_cuda = False 410 for arg in args: 411 if isinstance(arg, torch.Tensor) and arg.is_cuda: 412 has_cuda = True 413 break 414 415 def sync(): 416 if has_cuda: 417 # Ensures that segfaults are surfaced 418 torch.cuda.synchronize() 419 420 from torch._inductor.compile_fx import compile_fx_inner 421 422 try: 423 result = fx_g(*args) 424 assert isinstance(result, (tuple, list)) 425 assert not any(isinstance(x, (tuple, list)) for x in result) 426 except Exception: 427 return False 428 429 sync() 430 431 try: 432 compile_mod = compile_fx_inner(fx_g, args) 433 compile_mod(args) 434 sync() 435 except Exception as e: 436 if check_str is not None and check_str not in repr(e): 437 return False 438 print(repr(e)) 439 return True 440 return False 441 442 443def inductor_accuracy_fails( 444 fx_g, args, check_str=None, *, require_fp64=False, ignore_non_fp=False 445): 446 from torch._inductor.compile_fx import compile_fx_inner 447 448 return backend_aot_accuracy_fails( 449 fx_g, 450 args, 451 compile_fx_inner, 452 require_fp64=require_fp64, 453 ignore_non_fp=ignore_non_fp, 454 ) 455 456 457backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True) 458 459 460# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 461# REPRO MAIN 462# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 463 464 465def repro_common(options, mod, load_args): 466 # Invariant for graphs we generate with the repro script 467 assert not any(mod.named_parameters()) 468 for n, b in mod.named_buffers(): 469 if b.numel() > MAX_CONSTANT_NUMEL_INLINE: 470 log.warning( 471 "Constant %s was not serialized, generated random data instead. " 472 "If you think this is affecting you, please comment on " 473 "https://github.com/pytorch/pytorch/issues/100468", 474 n, 475 ) 476 477 if not hasattr(load_args, "_version"): 478 log.warning( 479 "load_args does not have a _version attribute, please file a bug to PyTorch " 480 "and describe how you generate this repro script" 481 ) 482 else: 483 if load_args._version > 0: 484 log.warning( 485 "load_args is version %s, but this version of PyTorch only supports " 486 "version 0. We will try to run it anyway but there may be an incompatibility; " 487 "if so, try upgrading your version of PyTorch.", 488 load_args._version, 489 ) 490 491 nop_reader = NopInputReader() 492 load_args(nop_reader) 493 494 with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: 495 input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) 496 load_args(input_reader) 497 args = input_reader.args 498 499 # Turn mod into a GraphModule the slow way 500 # TODO: speed this up 501 mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args) 502 503 torch._inductor.config.generate_intermediate_hooks = True 504 505 return mod, args 506 507 508ACCURACY_FAILS: Dict[str, Callable[[nn.Module, Any], bool]] = { 509 "": inductor_fails, 510 # This might look inverted but it's not. strict_accuracy means "we will 511 # minify any time we see anything that diverges", whereas accuracy is more 512 # conservative, and will only minify if there is a meaningful fp64 513 # divergence 514 "accuracy": functools.partial( 515 inductor_accuracy_fails, require_fp64=True, ignore_non_fp=True 516 ), 517 "strict_accuracy": inductor_accuracy_fails, 518} 519 520 521def repro_minifier_query(options, mod, load_args): 522 mod, args = repro_common(options, mod, load_args) 523 fail_fn = functools.partial( 524 ACCURACY_FAILS[options.accuracy], check_str=options.check_str 525 ) 526 if fail_fn(mod, args): 527 sys.exit(1) 528 else: 529 sys.exit(0) 530 531 532def repro_minify(options, mod, load_args): 533 from functorch.compile import minifier 534 535 mod, args = repro_common(options, mod, load_args) 536 compiler_name = "inductor_accuracy" if options.accuracy != "" else "inductor" 537 538 favored_device = 1 if torch.cuda.device_count() >= 2 else 0 539 env_variables = {"CUDA_VISIBLE_DEVICES": str(favored_device)} 540 541 module_fails: Any 542 if options.isolate: 543 module_fails = functools.partial( 544 isolate_fails, 545 env=env_variables, 546 compiler_name=compiler_name, 547 save_dir=options.save_dir, 548 accuracy=options.accuracy, 549 tracing_mode=options.tracing_mode, 550 ) 551 else: 552 module_fails = ACCURACY_FAILS[options.accuracy] 553 554 minifier( 555 mod, 556 args, 557 module_fails=functools.partial(module_fails, check_str=options.check_str), 558 dump_state=functools.partial( 559 dump_compiler_graph_state, compiler_name=compiler_name 560 ), 561 save_dir=options.save_dir, 562 offload_to_disk=options.offload_to_disk, 563 skip_offload=options.skip_saving_eager_intermediates, 564 skip_sanity=options.skip_sanity, 565 max_granularity=options.max_granularity, 566 ) 567 568 569def repro_analyze(options, mod, load_args): 570 from torch._inductor.compile_fx import compile_fx_inner 571 from torch._inductor.hooks import intermediate_hook 572 573 mod, args = repro_common(options, mod, load_args) 574 575 # TODO: The logic for cloning inputs/models here is intentionally 576 # modeled off of run_fwd_maybe_bwd, but arguably it is better not to 577 # clone inputs (as you are doubling your effective GPU memory usage). 578 # It is certainly faster though! It probably makes sense to let the 579 # user specify the offload strategy. 580 581 with tqdm(desc="Compiling"): 582 compiled = compile_fx_inner(mod, args) 583 total = counters["inductor"]["intermediate_hooks"] 584 585 known_names = set() 586 587 def save_hook(name, val): 588 known_names.add(name) 589 if not options.skip_saving_inductor_intermediates: 590 writer.write_tensor(os.path.join("inductor", name), val) 591 pbar.update(1) # type: ignore[has-type] 592 593 writer = torch.utils._content_store.ContentStoreWriter( 594 options.save_dir, stable_hash=options.stable_hash 595 ) 596 reader = torch.utils._content_store.ContentStoreReader(options.save_dir) 597 598 new_args = clone_inputs(args) 599 with intermediate_hook(save_hook), tqdm( 600 desc="Saving inductor intermediates", total=total 601 ) as pbar: 602 compiled(new_args) 603 assert not new_args 604 605 def compare_tuples(tuple1, tuple2): 606 diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] 607 diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] 608 609 if not diff_values: 610 return None 611 else: 612 return " and ".join(f"{a} != {b}" for a, b in diff_values) 613 614 def check_hook(name, val): 615 meta = writer.compute_tensor_metadata(val) 616 meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) 617 reason = compare_tuples(meta, meta2) 618 if reason is not None: 619 pbar.write(f"NONDETERMINISTIC INDUCTOR at {name} ({reason})") 620 pbar.update(1) 621 622 if not options.skip_check_deterministic: 623 new_args = clone_inputs(args) 624 with intermediate_hook(check_hook), tqdm( 625 desc="Checking inductor determinism", total=total 626 ) as pbar: 627 compiled(new_args) 628 assert not new_args 629 630 class WriterInterp(fx.Interpreter): 631 def __init__(self, mod, subdir) -> None: 632 super().__init__(mod) 633 self.subdir = subdir 634 635 def run_node(self, n): 636 r = super().run_node(n) 637 name = n.name 638 if name in known_names: 639 pbar.update(1) 640 writer.write_tensor(os.path.join(self.subdir, name), r) 641 return r 642 643 # NB: the module cast doesn't actually do anything, since there are no 644 # parameters/buffers on the module 645 if not options.skip_saving_float64_intermediates: 646 new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) 647 with tqdm(desc="Saving float64 intermediates", total=total) as pbar: 648 WriterInterp(new_mod, "float64").boxed_run(new_args) 649 assert not new_args 650 651 class ExactReaderInterp(fx.Interpreter): 652 def run_node(self, n): 653 r = super().run_node(n) 654 name = n.name 655 if name in known_names: 656 meta = writer.compute_tensor_metadata(r) 657 meta2 = reader.read_tensor_metadata(os.path.join("float64", name)) 658 reason = compare_tuples(meta, meta2) 659 if reason is not None: 660 pbar.write(f"NONDETERMINISTIC FLOAT64 at {name} ({reason})") 661 pbar.update(1) 662 return r 663 664 # TODO: check eager determinism 665 666 if not options.skip_check_deterministic: 667 new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) 668 with tqdm(desc="Checking float64 determinism", total=total) as pbar: 669 ExactReaderInterp(new_mod).boxed_run(new_args) 670 assert not new_args 671 672 # Now that we've saved everything, interp through the eager graph 673 # and do comparisons 674 class ReaderInterp(fx.Interpreter): 675 def run_node(self, n): 676 r = super().run_node(n) 677 name = n.name 678 if name in known_names: 679 inductor = reader.read_tensor(os.path.join("inductor", name)) 680 float64 = reader.read_tensor(os.path.join("float64", name)) 681 logged = False 682 683 def log_error(msg, *args): 684 nonlocal logged 685 logged = True 686 pbar.write(f"DIVERGED at {name}: {msg % args}") 687 688 if not same( 689 r, 690 inductor, 691 float64, 692 tol=torch._dynamo.config.repro_tolerance, 693 equal_nan=True, 694 log_error=log_error, 695 ): 696 assert logged 697 pbar.update(1) 698 return r 699 700 with tqdm(desc="Checking divergence", total=total) as pbar: 701 ReaderInterp(mod).boxed_run(args) 702 assert not args 703 704 705def repro_get_args(options, mod, load_args): 706 mod, args = repro_common(options, mod, load_args) 707 return mod, args 708 709 710def repro_run(options, mod, load_args): 711 from torch._inductor.compile_fx import compile_fx_inner 712 713 mod, args = repro_common(options, mod, load_args) 714 715 from torch.cuda import synchronize 716 717 compiled = compile_fx_inner(mod, args) 718 719 if options.accuracy != "": 720 # We don't really respect --accuracy vs --strict-accuracy here, it 721 # seems counterintuitive 722 if not same_two_models( 723 mod, 724 compiled, 725 args, 726 only_fwd=True, 727 ignore_non_fp=config.repro_ignore_non_fp, 728 ): 729 raise AccuracyError("Bad accuracy detected") 730 else: 731 need_sync = False 732 for arg in args: 733 if isinstance(arg, torch.Tensor) and arg.is_cuda: 734 need_sync = True 735 break 736 ref = compiled(list(args)) 737 if need_sync: 738 synchronize() # ensure segfaults are surfaced 739 return lambda: compiled(list(args)) 740 741 742# TODO: lazily load the inputs or something, rather than cloning them 743def run_repro( 744 mod, 745 load_args, 746 *, 747 command="run", 748 accuracy: Union[bool, str] = "", 749 save_dir=None, 750 tracing_mode=None, 751 patch_code=None, 752 check_str=None, 753 **kwargs, 754): 755 for k in kwargs: 756 log.warning( 757 "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", 758 k, 759 ) 760 761 if accuracy is True: 762 accuracy = "accuracy" 763 elif accuracy is False: 764 accuracy = "" 765 766 if patch_code is not None: 767 log.warning( 768 "patch_code no longer works on this version of PyTorch, silently ignoring" 769 ) 770 771 parser = argparse.ArgumentParser( 772 description=f"""\ 773An after_aot repro script, typically triggering a bug in PyTorch Inductor. 774When run with no arguments, this script defaults to running '{command}'. 775Extra flags may be available; to find out more, try '{command} --help'. 776There are also alternate subcommands available, see below. 777 778default settings on this script: 779 {accuracy=} 780 {tracing_mode=} 781 {save_dir=} 782 {check_str=} 783""", 784 formatter_class=argparse.RawTextHelpFormatter, 785 ) 786 787 def common_flags(parser): 788 accuracy_group = parser.add_mutually_exclusive_group() 789 accuracy_group.add_argument( 790 "--no-accuracy", 791 dest="accuracy", 792 action="store_const", 793 const="", 794 default=accuracy, 795 help="do not test accuracy, just run the module and see if it errors", 796 ) 797 accuracy_group.add_argument( 798 "--accuracy", 799 action="store_const", 800 const="accuracy", 801 default=accuracy, 802 help="""\ 803test if the RMSE between the compiled module and the fp64 reference is greater 804than eager and the fp64 reference. This is usually more reliable than the 805standard allclose test, as we expect numeric differences from compiling, often 806improving accuracy over eager. RMSE test allows for compiled module to 807diverge greatly from eager, as long as this divergence moves it closer to the 808'true' mathematical value of the network. Caveats: (1) double precision can 809still suffer from rounding error, so it is not a perfect reference (see for 810example 'Herbie: Automatically Improving Floating Point Accuracy') for 811approaches that detect the necessary working precision and compute it in 812arbitrary precision floating point; unfortunately, this is not practical for 813tensor computation; (2) if there are not enough samples in the output being 814compared, we may get unlucky and have an unlucky greater RMSE than eager; this 815could be overcome by applying a more rigorous statistical test at some 816p-value, which we leave for future work. 817""", 818 ) 819 accuracy_group.add_argument( 820 "--strict-accuracy", 821 dest="accuracy", 822 action="store_const", 823 const="strict_accuracy", 824 default=accuracy, 825 help="""\ 826by default, when doing accuracy minification we will reject reductions which 827change the divergence from a floating point divergence to a integral/boolean 828divergence. This is because some operations like ReLU involve temporarily 829sharp boundaries that smooth out again afterwards; without requiring 830divergence on floating point, the minifier will often fixate on divergent 831boolean tensor even though this is not the true source of the divergence. 832However, rejecting these reductions makes it more difficult for the minifier 833to make process. Using this option will let the minifier progress for ALL 834divergences--you just might not end up with a useful repro in the end.""", 835 ) 836 837 parser.add_argument( 838 "--save-dir", 839 type=str, 840 default=save_dir, 841 metavar="DIR", 842 help="directory where saved inputs live", 843 ) 844 parser.add_argument( 845 "--no-save-dir", 846 dest="save_dir", 847 action="store_const", 848 const=None, 849 help="don't use any directory for saved inputs", 850 ) 851 parser.add_argument( 852 "--tracing-mode", 853 type=str, 854 metavar="{real,fake,symbolic}", 855 default=tracing_mode, 856 help="how to trace the repro module into a GraphModule with metadata", 857 ) 858 859 subparsers = parser.add_subparsers( 860 dest="command", metavar="{run,minify,analyze}", required=True 861 ) 862 863 parser_run = subparsers.add_parser( 864 "run", 865 help="just run the repro", 866 ) 867 common_flags(parser_run) 868 869 parser_minify = subparsers.add_parser( 870 "minify", help="run the minifier on the repro" 871 ) 872 common_flags(parser_minify) 873 parser_get_args = subparsers.add_parser("get_args", help="get the args") 874 common_flags(parser_get_args) 875 parser_minify_isolate = parser_minify.add_mutually_exclusive_group() 876 parser_minify_isolate.add_argument( 877 "--isolate", 878 action="store_true", 879 default=True, 880 help="run in separate processes to avoid interference (default)", 881 ) 882 parser_minify_isolate.add_argument( 883 "--no-isolate", 884 dest="isolate", 885 action="store_false", 886 help="speed up by running all compilation in same process", 887 ) 888 parser_minify.add_argument( 889 "--skip-saving-eager-intermediates", 890 action="store_true", 891 help="skip saving eager intermediates on --minify", 892 ) 893 # TODO: make this an option for --analyze too 894 parser_minify.add_argument( 895 "--offload-to-disk", 896 action="store_true", 897 help="during minification, offload delta debugging intermediates to disk. Use if you're OOMing", 898 ) 899 parser_minify.add_argument( 900 "--skip-sanity", 901 action="store_true", 902 help="skip sanity check at beginning of minification on original graph", 903 ) 904 parser_minify.add_argument( 905 "--max-granularity", 906 type=int, 907 default=None, 908 help="start at this granularity and work down; must be power of 2", 909 ) 910 parser_minify.add_argument( 911 "--check-str", 912 type=str, 913 default=check_str, 914 help="require minified program to fail with error containing this string", 915 ) 916 917 parser_analyze = subparsers.add_parser( 918 "analyze", help="run the accuracy analyzer on the repro" 919 ) 920 common_flags(parser_analyze) 921 parser_analyze.add_argument( 922 "--skip-saving-inductor-intermediates", 923 action="store_true", 924 help="skip saving inductor intermediates on --analyze", 925 ) 926 parser_analyze.add_argument( 927 "--skip-saving-float64-intermediates", 928 action="store_true", 929 help="skip saving float64 intermediates", 930 ) 931 parser_analyze.add_argument( 932 "--skip-check-deterministic", 933 action="store_true", 934 help="skip checking that the network is deterministic", 935 ) 936 parser_analyze.add_argument( 937 "--stable-hash", 938 action="store_true", 939 help="use SHA-1 checksum instead of fast (but possibly unsound) hash", 940 ) 941 942 # Run the repro in the context of minification, inverting exit code meaning 943 parser_minifier_query = subparsers.add_parser( 944 "minifier-query", 945 ) 946 common_flags(parser_minifier_query) 947 parser_minifier_query.add_argument( 948 "--check-str", 949 type=str, 950 default=check_str, 951 help="require minified program to fail with error containing this string", 952 ) 953 954 args = None 955 if len(sys.argv) <= 1: 956 args = [command, *sys.argv[1:]] 957 958 options = parser.parse_args(args) 959 COMMAND_FNS = { 960 "minify": repro_minify, 961 "analyze": repro_analyze, 962 "minifier-query": repro_minifier_query, 963 "run": repro_run, 964 "get_args": repro_get_args, 965 } 966 return COMMAND_FNS[options.command](options, mod, load_args) 967