1# mypy: ignore-errors 2 3import datetime 4import difflib 5import functools 6import inspect 7import json 8import os 9import re 10import tempfile 11import threading 12import unittest 13from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union 14 15import torch 16import torch._dynamo 17import torch.utils._pytree as pytree 18from torch._dynamo.utils import clone_input 19from torch._library.custom_ops import CustomOpDef 20from torch._subclasses.schema_check_mode import SchemaCheckMode 21from torch._utils_internal import get_file_path_2 22from torch.overrides import TorchFunctionMode 23from torch.testing._internal.optests import ( 24 aot_autograd_check, 25 autograd_registration_check, 26 fake_check, 27) 28 29 30def dontGenerateOpCheckTests(reason: str): 31 def inner(fun): 32 fun._torch_dont_generate_opcheck_tests = True 33 return fun 34 35 return inner 36 37 38def is_abstract(tensor: torch.Tensor) -> bool: 39 if tensor.is_meta: 40 return True 41 if torch._subclasses.fake_tensor.is_fake(tensor): 42 return True 43 return False 44 45 46def safe_schema_check( 47 op: torch._ops.OpOverload, 48 args: Tuple[Any, ...], 49 kwargs: Dict[str, Any], 50 *, 51 copy_inputs: bool = True, 52) -> Any: 53 if copy_inputs: 54 args, kwargs = deepcopy_tensors((args, kwargs)) 55 if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): 56 return None 57 with SchemaCheckMode(): 58 result = op(*args, **kwargs) 59 return result 60 61 62def safe_autograd_registration_check( 63 op: torch._ops.OpOverload, 64 args: Tuple[Any, ...], 65 kwargs: Dict[str, Any], 66 *, 67 copy_inputs: bool = True, 68) -> None: 69 if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): 70 return 71 if copy_inputs: 72 args, kwargs = deepcopy_tensors((args, kwargs)) 73 # Don't perform autograd_registration_check if none of the inputs require grad. 
74 if not pytree.tree_any_only( 75 torch.Tensor, lambda x: x.requires_grad, (args, kwargs) 76 ): 77 return 78 return autograd_registration_check(op, args, kwargs) 79 80 81def safe_fake_check( 82 op: torch._ops.OpOverload, 83 args: Tuple[Any, ...], 84 kwargs: Dict[str, Any], 85 *, 86 copy_inputs: bool = True, 87) -> None: 88 if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): 89 return None 90 if copy_inputs: 91 args, kwargs = deepcopy_tensors((args, kwargs)) 92 return fake_check(op, args, kwargs) 93 94 95def safe_aot_autograd_check( 96 op: torch._ops.OpOverload, 97 args: Tuple[Any, ...], 98 kwargs: Dict[str, Any], 99 dynamic: bool, 100 *, 101 copy_inputs: bool = True, 102) -> Any: 103 # NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy 104 # inputs. 105 if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): 106 return None 107 108 def func(*args, **kwargs): 109 args, kwargs = pytree.tree_map_only(torch.Tensor, torch.clone, (args, kwargs)) 110 return op(*args, **kwargs) 111 112 # aot_autograd_check runs func(*args, **kwargs) multiple times 113 # and assumes `func` does not modify its inputs. 114 return aot_autograd_check(func, args, kwargs, dynamic, check_gradients="auto") 115 116 117def deepcopy_tensors(inputs: Any) -> Any: 118 return pytree.tree_map_only(torch.Tensor, clone_input, inputs) 119 120 121# Test util requirements 122# - The test util must have signature (op: OpOverload, args, kwargs) 123# - The test util must NOT mutate args, kwargs. 124# - The test utils in this list must not be prefixes of each other. For example, 125# having both "test_schema" and "test_schema_is_functional" is NOT OK. 126# - The order of items in this dict matters (for opcheck), we'll run them 127# in order. 
ALL_TEST_UTILS = {
    "test_schema": safe_schema_check,
    "test_autograd_registration": safe_autograd_registration_check,
    "test_faketensor": safe_fake_check,
    "test_aot_dispatch_static": functools.partial(
        safe_aot_autograd_check,
        dynamic=False,
    ),
    "test_aot_dispatch_dynamic": functools.partial(
        safe_aot_autograd_check,
        dynamic=True,
    ),
}

GDOC = "https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit"

DEFAULT_TEST_UTILS = [
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
]

DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [
    "test_aot_dispatch_static",
]


def generate_opcheck_tests(
    testcase: Any,
    namespaces: List[str],
    failures_dict_path: Optional[str] = None,
    additional_decorators: Optional[Dict[str, List[Callable]]] = None,
    test_utils: List[str] = DEFAULT_TEST_UTILS,
) -> None:
    """Given an existing TestCase, use the existing tests to generate
    additional validation tests for custom operators.

    For {all existing tests in the TestCase} x {all test utils},
    we will generate one new test. The new test runs a TorchFunctionMode
    that intercepts ``op(*args, **kwargs)`` calls and invokes
    ``test_util(op, args, kwargs)``, where ``op`` is an operator.

    The test_util that we support are in ALL_TEST_UTILS. They are:
    - test_schema: This runs SchemaCheckMode.
    - test_autograd_registration: This runs autograd_registration_check.
    - test_faketensor: This runs fake_check.
    - test_aot_dispatch_static: This runs aot_autograd_check, which:
        checks that the outputs (and gradients, if they are computable)
        are the same under eager-mode PyTorch and using AOTAutograd.
    - test_aot_dispatch_dynamic: Same as aot_dispatch_static, but
        runs AOTAutograd using dynamic shapes instead of static shapes.

    The generated test will have name ``{test_util}__{original_name}``.
    For example, if there is a method named ``test_cumsum``, then
    we will generate a ``test_schema__test_cumsum``,
    ``test_faketensor__test_cumsum``, etc.

    For more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit

    Args:
        testcase: The testcase we will modify and generate additional tests for.
        namespaces: We will only intercept calls to custom operators with these
            namespaces.
        failures_dict_path: See ``validate_failures_dict_structure`` for more
            details. Defaults to ``failures_dict.json`` in the directory of the
            file that calls this function.
        additional_decorators: Maps a generated test name (e.g.
            ``test_schema__test_cumsum`` or ``test_pt2_compliant_tag_<op>``) to
            a list of decorators applied to that generated test.
        test_utils: a list of test_utils to generate. Example: ["test_schema", "test_faketensor"]

    Raises:
        RuntimeError: if a generated test name already exists on ``testcase``,
            or if the failures dict is malformed / badly formatted.
    """
    if additional_decorators is None:
        additional_decorators = {}
    test_methods = [
        m
        for m in dir(testcase)
        if m.startswith("test_") and callable(getattr(testcase, m))
    ]
    if failures_dict_path is None:
        # The default failures_dict_path is failures_dict.json in
        # the same directory as the test file (the caller's file).
        prev_frame = inspect.currentframe().f_back
        filename = inspect.getframeinfo(prev_frame)[0]
        failures_dict_path = get_file_path_2(
            os.path.dirname(filename), "failures_dict.json"
        )
    failures_dict = FailuresDict.load(
        failures_dict_path, create_file=should_update_failures_dict()
    )
    validate_failures_dict_structure(failures_dict, test_utils, testcase)
    validate_failures_dict_formatting(failures_dict_path)

    def construct_method(attr, prefix, tester):
        # Creates and attaches ``{prefix}__{attr}``, which runs the original
        # test method under an OpCheckMode using ``tester``.
        method = getattr(testcase, attr)
        if getattr(method, "_torch_dont_generate_opcheck_tests", False):
            return
        new_method_name = prefix + "__" + attr

        @functools.wraps(method)
        def new_method(*args, **kwargs):
            with OpCheckMode(
                namespaces,
                prefix,
                tester,
                failures_dict,
                f"{testcase.__name__}.{new_method_name}",
                failures_dict_path,
            ):
                result = method(*args, **kwargs)
            return result

        if pytestmark := new_method.__dict__.get("pytestmark"):
            import pytest

            # check if we need to simplify the parametrize marks
            # NB: you need to add this mark to your pytest.ini
            opcheck_only_one = False
            for mark in pytestmark:
                if isinstance(mark, pytest.Mark) and mark.name == "opcheck_only_one":
                    opcheck_only_one = True

            if opcheck_only_one:
                new_pytestmark = []
                for mark in pytestmark:
                    if isinstance(mark, pytest.Mark) and mark.name == "parametrize":
                        argnames, argvalues = mark.args
                        assert not mark.kwargs, "NYI"
                        # Special case for device, we want to run on all
                        # devices
                        if argnames != "device":
                            # Keep only the first parametrized value.
                            new_pytestmark.append(
                                pytest.mark.parametrize(
                                    argnames, (next(iter(argvalues)),)
                                )
                            )
                            continue
                    new_pytestmark.append(mark)
                new_method.__dict__["pytestmark"] = new_pytestmark

        if new_method_name in additional_decorators:
            for dec in additional_decorators[new_method_name]:
                new_method = dec(new_method)

        if hasattr(testcase, new_method_name):
            raise RuntimeError(
                f"Tried to autogenerate {new_method_name} but {testcase} already "
                f"has method named {new_method_name}. Please rename the original "
                f"method on the TestCase."
            )
        setattr(testcase, new_method_name, new_method)

    selected_utils = {name: ALL_TEST_UTILS[name] for name in test_utils}
    for attr in test_methods:
        for prefix, tester in selected_utils.items():
            construct_method(attr, prefix, tester)

    generate_tag_tests(testcase, failures_dict, additional_decorators)


def generate_tag_tests(testcase, failures_dict, additional_decorators):
    """Add one ``test_pt2_compliant_tag_<op>`` test per operator in the failures dict.

    Each generated test fails if the operator carries
    ``torch.Tag.pt2_compliant_tag`` but has at least one "xfail" entry in the
    failures dict (entries for ``test_aot_dispatch_static`` are ignored). It is
    skipped if the operator cannot be looked up in the current process.

    Raises:
        RuntimeError: if the generated test name already exists on ``testcase``.
    """

    def generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests):
        def inner(self):
            try:
                op = torch._library.utils.lookup_op(qualname)
            except AttributeError as e:
                # Operator not importable in this test file
                raise unittest.SkipTest(f"Can't import operator {qualname}") from e
            op_marked_as_compliant = torch.Tag.pt2_compliant_tag in op.tags
            if not op_marked_as_compliant:
                return
            if not definitely_not_pt2_compliant:
                return
            raise AssertionError(
                f"op '{qualname}' was tagged with torch.Tag.pt2_compliant_tag "
                f"but it failed some of the generated opcheck tests "
                f"({xfailed_tests}). This may lead to silent correctness issues, "
                f"please fix this."
            )

        return inner

    for qualname, test_dict in failures_dict.data.items():
        xfailed_tests = [
            test
            for test, status_dict in test_dict.items()
            # We're about to delete the following test after Ed's PR
            # to specialize on C++ .size() calls
            if "test_aot_dispatch_static" not in test
            and status_dict["status"] == "xfail"
        ]
        definitely_not_pt2_compliant = len(xfailed_tests) > 0
        generated = generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests)

        # Could result in collisions, but unlikely. We'll raise if we see one below.
        mangled_qualname = qualname.replace("::", "_").replace(".", "_")
        test_name = "test_pt2_compliant_tag_" + mangled_qualname

        # You can skip this test via the additional_decorators argument
        # in generate_opcheck_tests
        if test_name in additional_decorators:
            for decorator in additional_decorators[test_name]:
                generated = decorator(generated)

        if hasattr(testcase, test_name):
            raise RuntimeError(
                f"Tried to generate a test named {test_name}, but it exists "
                f"already. This could be because of a name collision (where "
                f"we generated two tests with the same name), or where we "
                f"generated a test with the same name as an existing test."
            )
        setattr(testcase, test_name, generated)


TEST_OPTIONS = ("xfail", "skip", "xsuccess")


def validate_failures_dict_formatting(failures_dict_path: str) -> None:
    """Check that the failures dict file is byte-identical to its canonical dump.

    With ``PYTORCH_OPCHECK_ACCEPT=1`` the file is rewritten in canonical form
    instead of raising.

    Raises:
        RuntimeError: with a unified diff when the formatting differs.
    """
    with open(failures_dict_path) as fp:
        actual = fp.read()
    failures_dict = FailuresDict.load(failures_dict_path)
    expected = failures_dict._save(to_str=True)
    if actual == expected:
        return
    if should_update_failures_dict():
        failures_dict.save()
        return
    diff = "".join(
        difflib.unified_diff(
            actual.splitlines(keepends=True), expected.splitlines(keepends=True)
        )
    )
    raise RuntimeError(
        f"\n{diff}\n\nExpected the failures dict to be formatted "
        f"a certain way. Please see the above diff; you can correct "
        f"this either manually or by re-running the test with "
        f"PYTORCH_OPCHECK_ACCEPT=1"
    )


def validate_failures_dict_structure(
    failure_dict: "FailuresDict", test_utils: List[str], testcase: Any
) -> None:
    """Validates the failures dict.

    The failure dict (``FailuresDict.data``) looks something like the following.
    It maps operator name (qualname) to autogenerated test names of the form
    ``{TestCaseName}.{test_util}__{original_test_name}``.
    Each autogenerated test may have a check for the operator (if the operator is
    called by the test); the dictionary specifies if we should skip the check,
    or if we expect some check to fail.

    {
        "fbgemm::split_lengths": {
            "TestFBGEMM.test_schema__test_split_lengths": {
                "comment": "you can put whatever you want into the comment section",
                "status": "xfail",
            },
            "TestFBGEMM.test_schema__test_split_lengths_empty": {
                "comment": "",
                "status": "skip",
            },
        },
        "fbgemm::gather_lengths": {
            "TestFBGEMM.test_schema__test_gather_lengths": {
                "comment": "",
                "status": "skip",
            },
        },
    }

    Raises:
        RuntimeError: if an entry has the wrong keys, an unknown status, a
            malformed test name, or refers to a test that does not exist on
            ``testcase``.
    """
    data = failure_dict.data
    for test_to_option in data.values():
        for test_name, test_dict in test_to_option.items():
            if set(test_dict.keys()) != {"comment", "status"}:
                raise RuntimeError(
                    "in failures_dict, expected sub-dict to have keys 'comment' and 'status'"
                )
            test_option = test_dict["status"]
            if test_option not in TEST_OPTIONS:
                raise RuntimeError(
                    f"In failures_dict, got status={test_option} but it needs to be in {TEST_OPTIONS}"
                )
            # Split on the first "." only so a parametrization suffix that
            # contains dots (e.g. "[1.5]") stays part of the test name.
            test_class, sep, actual_test_name = test_name.partition(".")
            if not sep:
                raise RuntimeError(
                    f"In failures_dict, test name '{test_name}' should have the "
                    f"form '<TestCaseName>.<test_name>'"
                )
            if not any(actual_test_name.startswith(test) for test in test_utils):
                raise RuntimeError(
                    f"In failures_dict, test name '{test_name}' should begin with one of {test_utils}"
                )
            for test in test_utils:
                if not actual_test_name.startswith(test):
                    continue
                # Strip "{test_util}__" to recover the original test name.
                base_test_name = actual_test_name[len(test) + 2 :]
                # remove potential pytest parametrization suffix
                base_test_name = re.sub(r"\[.*\]", "", base_test_name)
                if testcase.__name__ != test_class:
                    continue
                if hasattr(testcase, base_test_name):
                    continue
                raise RuntimeError(
                    f"In failures dict, got test name '{test_name}'. We parsed this as "
                    f"running test '{test}' on '{base_test_name}', but "
                    f"{base_test_name} does not exist on the TestCase '{testcase.__name__}'. "
                    f"Maybe you need to change the test name?"
                )


def should_update_failures_dict() -> bool:
    """Return True iff the environment variable ``PYTORCH_OPCHECK_ACCEPT`` is "1"."""
    key = "PYTORCH_OPCHECK_ACCEPT"
    return key in os.environ and os.environ[key] == "1"


# Thread-local flag recording whether an OpCheckMode is active on the current
# thread. The attribute assigned here only exists on the importing thread, so
# readers fall back to False via getattr.
_is_inside_opcheck_mode = threading.local()
_is_inside_opcheck_mode.value = False


def is_inside_opcheck_mode():
    """Return True if an OpCheckMode is currently entered on this thread."""
    return getattr(_is_inside_opcheck_mode, "value", False)


class OpCheckMode(TorchFunctionMode):
    """
    For a given test, OpCheckMode intercepts calls to operators and runs
    test_util(op, args, kwargs) for each intercepted (op, args, kwargs).
    """

    def __init__(
        self,
        namespaces: List[str],
        test_util_name: str,
        test_util: Callable,
        failures_dict: "FailuresDict",
        test_name: str,
        failures_dict_path: str,
    ):
        # We will intercept calls to ops with these namespaces
        self.namespaces = namespaces
        # The test utility function. Its signature should be (op, args, kwargs) -> None.
        # Examples of test utilities are: schema_check, make_fx_check
        self.test_util = test_util
        self.test_util_name = test_util_name
        # The name of the test that is running this OpCheckMode.
        self.test_name = test_name
        # Maps qualname -> test_name -> skip/xfail
        # Tells us if we should skip a test or assert that there is a failure.
        self.failures_dict = failures_dict
        # Location of the failures dict. Makes it so that the error message is better.
        self.failures_dict_path = failures_dict_path

        # OpCheckMode suppresses errors, collects them here, and then raises them on exit.
        # Maps qualname -> List[(Exception, func, maybe args, maybe kwargs)]
        self.seen_ops_to_errors = {}

    def maybe_raise_errors_on_exit(self) -> None:
        """Compare collected errors with the failures dict and raise on mismatch.

        With ``PYTORCH_OPCHECK_ACCEPT=1`` the failures dict is updated in memory
        instead of raising.

        Raises:
            OpCheckError: on an unexpected success (status "xfail", no errors)
                or an unexpected failure (status "xsuccess", errors recorded).
        """
        # Check expected failures first
        for qualname, errors in self.seen_ops_to_errors.items():
            if errors:
                continue
            if should_update_failures_dict():
                self.failures_dict.set_status(
                    qualname, self.test_name, "xsuccess", comment=""
                )
            elif self.failures_dict.get_status(qualname, self.test_name) == "xfail":
                raise OpCheckError(
                    f"generate_opcheck_tests: Unexpected success for operator "
                    f"{qualname} on test {self.test_name}. This may mean that "
                    f"you have fixed this test failure. Please rerun the test with "
                    f"PYTORCH_OPCHECK_ACCEPT=1 to automatically update the test runner "
                    f"or manually remove the "
                    f"expected failure in the failure dict at "
                    f"{self.failures_dict_path}. "
                    f"For more details, see "
                    f"{GDOC}"
                )

        failed_ops = [
            qualname
            for qualname, errors in self.seen_ops_to_errors.items()
            if errors
            and self.failures_dict.get_status(qualname, self.test_name) == "xsuccess"
        ]
        if not failed_ops:
            return

        if should_update_failures_dict():
            for op in failed_ops:
                self.failures_dict.set_status(op, self.test_name, "xfail")
            return

        # Raise from the first error but also report about all of them to make
        # recording xfails easier.
        ex, op, args, kwargs = self.seen_ops_to_errors[failed_ops[0]][0]
        repro_command = generate_repro(
            self.test_util_name, op, args, kwargs, save_data=should_print_better_repro()
        )
        raise OpCheckError(
            f"Test generated by `generate_opcheck_tests`, {self.test_name}, "
            f"failed on operators {failed_ops}. This usually means that the "
            f"operators are not implemented correctly and may lead to silently "
            f"incorrect behavior. Set PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 for a standalone repro, "
            f"or please see "
            f"{GDOC} "
            f"for more recommendations. "
            f"To reproduce this problem locally, try to run the following:\n{repro_command}"
        ) from ex

    def __enter__(self, *args, **kwargs):
        self.prev_is_opcheck_mode = is_inside_opcheck_mode()
        # None means the variable was unset; __exit__ then removes it again
        # instead of leaving an empty string behind.
        self.prev_dynamo_disable = os.environ.get("TORCHDYNAMO_DISABLE")
        _is_inside_opcheck_mode.value = True
        os.environ["TORCHDYNAMO_DISABLE"] = "1"
        return super().__enter__(*args, **kwargs)

    def __exit__(self, *args, **kwargs):
        _is_inside_opcheck_mode.value = self.prev_is_opcheck_mode
        if self.prev_dynamo_disable is None:
            os.environ.pop("TORCHDYNAMO_DISABLE", None)
        else:
            os.environ["TORCHDYNAMO_DISABLE"] = self.prev_dynamo_disable
        try:
            self.maybe_raise_errors_on_exit()
            if should_update_failures_dict():
                self.failures_dict.save()
        finally:
            result = super().__exit__(*args, **kwargs)
        return result

    def run_test_util(self, op, args, kwargs):
        """Invoke the test util on already-copied inputs, ignoring fake-tensor incompatibilities."""
        try:
            self.test_util(op, args, kwargs, copy_inputs=False)
        except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
            # We might get here if the input is already a FakeTensor
            # or if we're in a torch.compile block. Just ignore these
            # since we can't handle them and reporting them as failures
            # is too noisy.
            pass

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # Only intercept calls to operators
        if not isinstance(func, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
            return func(*args, **kwargs)
        if (
            torch.jit.is_tracing()
            or torch.jit.is_scripting()
            or torch._dynamo.is_compiling()
        ):
            return func(*args, **kwargs)
        # Pre-existing code may not use the .default overload. If we see an
        # OpOverloadPacket and we cannot resolve the overload, then we just throw
        # and ask the user to clarify. Otherwise, we attempt to resolve the overload.
        if isinstance(func, torch._ops.OpOverloadPacket):
            func = resolve_unique_overload_or_throw(func)
        qualname = func.name()
        ns = qualname.split("::")[0]
        if ns not in self.namespaces:
            return func(*args, **kwargs)

        # Copy inputs before the real call so the test util sees the
        # pre-call values even if the op mutates its arguments.
        args_c, kwargs_c = deepcopy_tensors((args, kwargs))
        result = func(*args, **kwargs)

        option = self.failures_dict.get_status(qualname, self.test_name)
        if option == "xsuccess" or option == "xfail":
            # Suppress all errors during execution. Raise them during __exit__.
            try:
                if qualname not in self.seen_ops_to_errors:
                    self.seen_ops_to_errors[qualname] = []
                self.run_test_util(func, args_c, kwargs_c)
            except Exception as ex:
                if should_print_better_repro():
                    self.seen_ops_to_errors[qualname].append((ex, func, args, kwargs))
                else:
                    self.seen_ops_to_errors[qualname].append((ex, func, None, None))
        elif option == "skip":
            pass
        return result


def should_print_better_repro() -> bool:
    """If set, the tests generated by `generate_opcheck_tests` will print a
    repro command on failure.

    In order to print the repro command, we need to save some tensors to disk.
    These will be saved under the following directory:
    {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/.

    Although this is a temp folder, it will usually not automatically get cleaned
    up, so you'll need to manually delete it.

    Returns True iff the environment variable
    ``PYTORCH_OPCHECK_PRINT_BETTER_REPRO`` is "1".
    """
    # Environment values are always strings, so only "1" can match.
    return os.environ.get("PYTORCH_OPCHECK_PRINT_BETTER_REPRO") == "1"


def opcheck(
    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
    raise_exception: bool = True,
) -> Dict[str, Union[str, Exception]]:
    """See torch.library.opcheck for docstring

    Returns a dict mapping each requested test util name to "SUCCESS" or,
    when ``raise_exception=False``, to the exception it raised.
    """

    if kwargs is None:
        kwargs = {}
    if isinstance(op, CustomOpDef):
        op = op._opoverload
    if isinstance(op, torch._ops.OpOverloadPacket):
        op = resolve_unique_overload_or_throw(op)
    if not isinstance(op, torch._ops.OpOverload):
        raise ValueError(
            f"opcheck(op, ...): op must be instance of torch._ops.OpOverload, "
            f"e.g. torch.ops.aten.sin.default, got {type(op)}"
        )
    if test_utils == "ALL":
        test_utils = tuple(ALL_TEST_UTILS.keys())
    if isinstance(test_utils, str):
        test_utils = (test_utils,)
    if not isinstance(test_utils, (tuple, list)) or not set(test_utils).issubset(
        ALL_TEST_UTILS.keys()
    ):
        raise ValueError(
            f"opcheck(op, ..., test_utils={test_utils}), expected test_utils "
            f"to be subset of {tuple(ALL_TEST_UTILS.keys())} but it was not"
        )

    results_dict = {}
    for test_util in test_utils:
        tester = ALL_TEST_UTILS[test_util]
        try:
            tester(op, args, kwargs)
            results_dict[test_util] = "SUCCESS"
        except Exception as ex:
            if raise_exception:
                raise OpCheckError(
                    f"opcheck(op, ...): {test_util} failed with {ex} "
                    f"(scroll up for stack trace)"
                ) from ex
            results_dict[test_util] = ex
    return results_dict


class OpCheckError(Exception):
    pass


def generate_repro(
    test: str,
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    *,
    save_data: bool,
    dry_run: bool = False,
) -> str:
    """Return a standalone Python script that reruns ``opcheck`` for one op.

    If ``save_data`` is True, ``(args, kwargs)`` are ``torch.save``d into a
    temp directory (unless ``dry_run``) and the script loads them from there;
    otherwise the script contains placeholders for the inputs.
    """
    if save_data:
        now = datetime.datetime.now()
        path = os.path.join(tempfile.gettempdir(), "pytorch_opcheck_safe_to_delete")
        unix_timestamp = datetime.datetime.timestamp(now) * 100000
        filepath = os.path.join(path, f"repro_{unix_timestamp}.pt")
        if not dry_run:
            os.makedirs(path, exist_ok=True)
            torch.save((args, kwargs), filepath)
        args_kwargs = f'args, kwargs = torch.load("{filepath}")'
    else:
        args_kwargs = (
            "# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1\n"
            "# we will fill them in same (args, kwargs) as in your test\n"
            "args = ()  # args to the operator\n"
            "kwargs = {}  # kwargs to the operator"
        )

    ns, name = op._schema.name.split("::")
    overload = op._overloadname

    repro_command = (
        f"# =========================================================\n"
        f"# BEGIN REPRO SCRIPT\n"
        f"# =========================================================\n"
        f"import torch\n"
        f"from torch.testing._internal.optests import opcheck\n"
        f"\n"
        f"# Make sure you have loaded the library that contains the op\n"
        f"# via an import or torch.ops.load_library(...)\n"
        f"op = torch.ops.{ns}.{name}.{overload}\n"
        f"\n"
        f"{args_kwargs}\n"
        f'opcheck(op, args, kwargs, test_utils="{test}")\n'
        f"# =========================================================\n"
        f"# END REPRO SCRIPT\n"
        f"# =========================================================\n"
    )
    return repro_command


def resolve_unique_overload_or_throw(
    op: torch._ops.OpOverloadPacket,
) -> torch._ops.OpOverload:
    """Return the single overload of ``op``.

    Raises:
        RuntimeError: if the operator has zero or more than one overload.
    """
    all_schemas = torch._C._jit_get_schemas_for_operator(op._qualified_op_name)
    if len(all_schemas) != 1:
        raise RuntimeError(
            f"opcheck can only test operators without overloads. "
            f"Got the following overloads for {op._qualified_op_name}: "
            f"{[schema.overload_name for schema in all_schemas]}"
        )

    overload_name = all_schemas[0].overload_name
    if overload_name == "":
        return op.default
    return getattr(op, overload_name)


DUMP_OPTIONS = {"indent": 2, "sort_keys": True}


FailuresDictData = Dict[str, Dict[str, Dict[str, str]]]


VERSION = 1
DESCRIPTION = (
    f"This is a dict containing failures for tests autogenerated by "
    f"generate_opcheck_tests. "
    f"For more details, please see {GDOC}"
)


class FailuresDict:
    """In-memory view of a failures_dict.json file.

    ``data`` maps qualname -> test name -> {"status": ..., "comment": ...}.
    """

    def __init__(self, path: str, data: FailuresDictData):
        self.path = path
        self.data = data

    @staticmethod
    def load(path, *, create_file=False) -> "FailuresDict":
        """Load the failures dict at ``path``.

        If ``create_file`` is True and ``path`` does not exist, an empty dict
        is written to ``path`` and returned. An empty/whitespace-only file is
        treated as an empty dict.
        """
        if create_file and not os.path.exists(path):
            result = FailuresDict(path, {})
            # Write the file now so later readers (e.g. the formatting
            # validator) find it on disk.
            result.save()
            return result
        with open(path) as fp:
            contents = fp.read()
        if contents.strip() == "":
            dct = {
                "_description": DESCRIPTION,
                "data": {},
                "_version": VERSION,
            }
        else:
            dct = json.loads(contents)
            assert "data" in dct, f"{path}: missing 'data' key"
            assert "_version" in dct and dct["_version"] == VERSION, (
                f"{path}: expected _version == {VERSION}"
            )
        return FailuresDict(path, dct["data"])

    def _save(self, to_str=False) -> Optional[str]:
        to_dump = {
            "_description": DESCRIPTION,
            "data": self.data,
            "_version": VERSION,
        }
        # json.dumps doesn't end with a newline. Let's add one because files
        # should end in newlines.
        serialized = json.dumps(to_dump, **DUMP_OPTIONS) + "\n"
        if to_str:
            return serialized
        with open(self.path, "w") as fp:
            fp.write(serialized)
        return None

    def save(self) -> None:
        return self._save()

    def get_status(self, qualname: str, test_name: str) -> str:
        """Return the recorded status, defaulting to "xsuccess" when absent."""
        if qualname not in self.data:
            return "xsuccess"
        dct = self.data[qualname]
        if test_name not in dct:
            return "xsuccess"
        return dct[test_name]["status"]

    def set_status(
        self,
        qualname: str,
        test_name: str,
        status: str,
        *,
        comment: Optional[str] = None,
    ):
        """Record ``status`` for (qualname, test_name); "xsuccess" deletes the entry."""
        if qualname not in self.data:
            self.data[qualname] = {}
        dct = self.data[qualname]
        if test_name not in dct:
            dct[test_name] = {"status": None, "comment": ""}

        if status == "xsuccess":
            # The default status is "xsuccess".
            del dct[test_name]
        else:
            dct[test_name]["status"] = status
            if comment is not None:
                dct[test_name]["comment"] = comment