# mypy: ignore-errors

import datetime
import difflib
import functools
import inspect
import json
import os
import re
import tempfile
import threading
import unittest
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch._dynamo
import torch.utils._pytree as pytree
from torch._dynamo.utils import clone_input
from torch._library.custom_ops import CustomOpDef
from torch._subclasses.schema_check_mode import SchemaCheckMode
from torch._utils_internal import get_file_path_2
from torch.overrides import TorchFunctionMode
from torch.testing._internal.optests import (
    aot_autograd_check,
    autograd_registration_check,
    fake_check,
)


def dontGenerateOpCheckTests(reason: str):
    def inner(fun):
        fun._torch_dont_generate_opcheck_tests = True
        return fun

    return inner


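# Usage sketch for ``dontGenerateOpCheckTests`` (illustrative only; the
# TestCase and method names below are hypothetical). Decorated tests are
# excluded from the {test util} x {test} cross-product that
# ``generate_opcheck_tests`` builds:
#
#     class MyCustomOpTests(TestCase):
#         @dontGenerateOpCheckTests("does not exercise any custom ops")
#         def test_unrelated_behavior(self):
#             ...

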
def is_abstract(tensor: torch.Tensor) -> bool:
    if tensor.is_meta:
        return True
    if torch._subclasses.fake_tensor.is_fake(tensor):
        return True
    return False


def safe_schema_check(
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    *,
    copy_inputs: bool = True,
) -> Any:
    if copy_inputs:
        args, kwargs = deepcopy_tensors((args, kwargs))
    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
        return None
    with SchemaCheckMode():
        result = op(*args, **kwargs)
        return result


def safe_autograd_registration_check(
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    *,
    copy_inputs: bool = True,
) -> None:
    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
        return
    if copy_inputs:
        args, kwargs = deepcopy_tensors((args, kwargs))
    # Don't perform autograd_registration_check if none of the inputs require grad.
    if not pytree.tree_any_only(
        torch.Tensor, lambda x: x.requires_grad, (args, kwargs)
    ):
        return
    return autograd_registration_check(op, args, kwargs)


def safe_fake_check(
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    *,
    copy_inputs: bool = True,
) -> None:
    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
        return None
    if copy_inputs:
        args, kwargs = deepcopy_tensors((args, kwargs))
    return fake_check(op, args, kwargs)


def safe_aot_autograd_check(
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    dynamic: bool,
    *,
    copy_inputs: bool = True,
) -> Any:
    # NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy
    # inputs.
    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
        return None

    def func(*args, **kwargs):
        args, kwargs = pytree.tree_map_only(torch.Tensor, torch.clone, (args, kwargs))
        return op(*args, **kwargs)

    # aot_autograd_check runs func(*args, **kwargs) multiple times
    # and assumes `func` does not modify its inputs.
    return aot_autograd_check(func, args, kwargs, dynamic, check_gradients="auto")


def deepcopy_tensors(inputs: Any) -> Any:
    return pytree.tree_map_only(torch.Tensor, clone_input, inputs)


# Test util requirements
# - The test util must have signature (op: OpOverload, args, kwargs)
# - The test util must NOT mutate args, kwargs.
# - The test utils in this list must not be prefixes of each other. For example,
#   having both "test_schema" and "test_schema_is_functional" is NOT OK.
# - The order of items in this dict matters (for opcheck): we'll run them
#   in order.
ALL_TEST_UTILS = {
    "test_schema": safe_schema_check,
    "test_autograd_registration": safe_autograd_registration_check,
    "test_faketensor": safe_fake_check,
    "test_aot_dispatch_static": functools.partial(
        safe_aot_autograd_check,
        dynamic=False,
    ),
    "test_aot_dispatch_dynamic": functools.partial(
        safe_aot_autograd_check,
        dynamic=True,
    ),
}

GDOC = "https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit"

DEFAULT_TEST_UTILS = [
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
]

DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [
    "test_aot_dispatch_static",
]


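# Both ``generate_opcheck_tests`` and ``opcheck`` below accept a subset of
# these names via their ``test_utils`` argument; ``opcheck`` also accepts the
# string "ALL" as shorthand for every util. A minimal sketch (the operator
# and input are hypothetical):
#
#     opcheck(torch.ops.mylib.foo.default, (x,), test_utils="ALL")

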
def generate_opcheck_tests(
    testcase: Any,
    namespaces: List[str],
    failures_dict_path: Optional[str] = None,
    additional_decorators: Optional[Dict[str, Callable]] = None,
    test_utils: List[str] = DEFAULT_TEST_UTILS,
) -> None:
    """Given an existing TestCase, use the existing tests to generate
    additional validation tests for custom operators.

    For {all existing tests in the TestCase} x {all test utils},
    we will generate one new test. The new test runs a TorchFunctionMode
    that intercepts ``op(*args, **kwargs)`` calls and invokes
    ``test_util(op, args, kwargs)``, where ``op`` is an operator.

    The test utils we support are in ALL_TEST_UTILS. They are:
    - test_schema: This runs SchemaCheckMode.
    - test_autograd_registration: This runs autograd_registration_check.
    - test_faketensor: This runs CrossRefFakeMode.
    - test_aot_dispatch_static: This runs aot_autograd_check, which
        checks that the outputs (and gradients, if they are computable)
        are the same under eager-mode PyTorch and AOTAutograd.
    - test_aot_dispatch_dynamic: Same as test_aot_dispatch_static, but
        runs AOTAutograd using dynamic shapes instead of static shapes.

    The generated test will have name ``{test_util}__{original_name}``.
    For example, if there is a method named ``test_cumsum``, then
    we will generate a ``test_schema__test_cumsum``,
    ``test_faketensor__test_cumsum``, etc.

    For more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit

    Args:
        testcase: The testcase we will modify and generate additional tests for.
        namespaces: We will only intercept calls to custom operators with these
                    namespaces.
        failures_dict_path: See ``validate_failures_dict_structure`` for more details.
        additional_decorators: Maps the name of a generated test to a list of
                    additional decorators to apply to it.
        test_utils: A list of test utils to generate. Example: ["test_schema", "test_faketensor"]
    """
    if additional_decorators is None:
        additional_decorators = {}
    test_methods = [
        m
        for m in dir(testcase)
        if m.startswith("test_") and callable(getattr(testcase, m))
    ]
    if failures_dict_path is None:
        # The default failures_dict_path is failures_dict.json in
        # the same directory as the test file.
        prev_frame = inspect.currentframe().f_back
        filename = inspect.getframeinfo(prev_frame)[0]
        failures_dict_path = get_file_path_2(
            os.path.dirname(filename), "failures_dict.json"
        )
    failures_dict = FailuresDict.load(
        failures_dict_path, create_file=should_update_failures_dict()
    )
    validate_failures_dict_structure(failures_dict, test_utils, testcase)
    validate_failures_dict_formatting(failures_dict_path)

    def construct_method(attr, prefix, tester):
        method = getattr(testcase, attr)
        if getattr(method, "_torch_dont_generate_opcheck_tests", False):
            return
        new_method_name = prefix + "__" + attr

        @functools.wraps(method)
        def new_method(*args, **kwargs):
            with OpCheckMode(
                namespaces,
                prefix,
                tester,
                failures_dict,
                f"{testcase.__name__}.{new_method_name}",
                failures_dict_path,
            ):
                result = method(*args, **kwargs)
            return result

        if pytestmark := new_method.__dict__.get("pytestmark"):
            import pytest

            # check if we need to simplify the parametrize marks
            # NB: you need to add this mark to your pytest.ini
            opcheck_only_one = False
            for mark in pytestmark:
                if isinstance(mark, pytest.Mark) and mark.name == "opcheck_only_one":
                    opcheck_only_one = True

            if opcheck_only_one:
                new_pytestmark = []
                for mark in pytestmark:
                    if isinstance(mark, pytest.Mark) and mark.name == "parametrize":
                        argnames, argvalues = mark.args
                        assert not mark.kwargs, "NYI"
                        # Special case for device: we want to run on all
                        # devices
                        if argnames != "device":
                            new_pytestmark.append(
                                pytest.mark.parametrize(
                                    argnames, (next(iter(argvalues)),)
                                )
                            )
                            continue
                    new_pytestmark.append(mark)
                new_method.__dict__["pytestmark"] = new_pytestmark

        if new_method_name in additional_decorators:
            for dec in additional_decorators[new_method_name]:
                new_method = dec(new_method)

        if hasattr(testcase, new_method_name):
            raise RuntimeError(
                f"Tried to autogenerate {new_method_name} but {testcase} already "
                f"has a method named {new_method_name}. Please rename the original "
                f"method on the TestCase."
            )
        setattr(testcase, new_method_name, new_method)

    test_utils = {name: ALL_TEST_UTILS[name] for name in test_utils}
    for attr in test_methods:
        for prefix, tester in test_utils.items():
            construct_method(attr, prefix, tester)

    generate_tag_tests(testcase, failures_dict, additional_decorators)


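# Usage sketch (illustrative only; the TestCase, operator namespace, and test
# body are hypothetical). Given a suite whose tests call custom ops from the
# "mylib" namespace, this generates tests such as ``test_schema__test_foo``
# onto the same TestCase:
#
#     class MyCustomOpTests(TestCase):
#         def test_foo(self):
#             x = torch.randn(3)
#             torch.ops.mylib.foo(x)
#
#     generate_opcheck_tests(
#         MyCustomOpTests,
#         ["mylib"],
#         failures_dict_path="failures_dict.json",
#     )

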
def generate_tag_tests(testcase, failures_dict, additional_decorators):
    def generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests):
        def inner(self):
            try:
                op = torch._library.utils.lookup_op(qualname)
            except AttributeError as e:
                # Operator not importable in this test file
                raise unittest.SkipTest(f"Can't import operator {qualname}") from e
            op_marked_as_compliant = torch.Tag.pt2_compliant_tag in op.tags
            if not op_marked_as_compliant:
                return
            if not definitely_not_pt2_compliant:
                return
            raise AssertionError(
                f"op '{qualname}' was tagged with torch.Tag.pt2_compliant_tag "
                f"but it failed some of the generated opcheck tests "
                f"({xfailed_tests}). This may lead to silent correctness issues; "
                f"please fix this."
            )

        return inner

    for qualname, test_dict in failures_dict.data.items():
        xfailed_tests = [
            test
            for test, status_dict in test_dict.items()
            # We're about to delete the following test after Ed's PR
            # to specialize on C++ .size() calls
            if "test_aot_dispatch_static" not in test
            and status_dict["status"] == "xfail"
        ]
        definitely_not_pt2_compliant = len(xfailed_tests) > 0
        generated = generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests)

        # Could result in collisions, but unlikely. We'll raise if we see one below.
        mangled_qualname = qualname.replace("::", "_").replace(".", "_")
        test_name = "test_pt2_compliant_tag_" + mangled_qualname

        # You can skip this test via the additional_decorators argument
        # in generate_opcheck_tests
        if test_name in additional_decorators:
            for decorator in additional_decorators[test_name]:
                generated = decorator(generated)

        if hasattr(testcase, test_name):
            raise RuntimeError(
                f"Tried to generate a test named {test_name}, but it exists "
                f"already. This could be because of a name collision (where "
                f"we generated two tests with the same name), or because we "
                f"generated a test with the same name as an existing test."
            )
        setattr(testcase, test_name, generated)


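# For reference, operators opt into the compliance check above by carrying
# torch.Tag.pt2_compliant_tag. A sketch of tagging an op at definition time
# (illustrative; "mylib::foo" is hypothetical):
#
#     torch.library.define(
#         "mylib::foo",
#         "(Tensor x) -> Tensor",
#         tags=(torch.Tag.pt2_compliant_tag,),
#     )

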
TEST_OPTIONS = ("xfail", "skip", "xsuccess")


def validate_failures_dict_formatting(failures_dict_path: str) -> None:
    with open(failures_dict_path) as fp:
        actual = fp.read()
    failures_dict = FailuresDict.load(failures_dict_path)
    expected = failures_dict._save(to_str=True)
    if actual == expected:
        return
    if should_update_failures_dict():
        failures_dict = FailuresDict.load(failures_dict_path)
        failures_dict.save()
        return
    expected = expected.splitlines(keepends=True)
    actual = actual.splitlines(keepends=True)
    diff = difflib.unified_diff(actual, expected)
    diff = "".join(diff)
    raise RuntimeError(
        f"\n{diff}\n\nExpected the failures dict to be formatted "
        f"a certain way. Please see the above diff; you can correct "
        f"this either manually or by re-running the test with "
        f"PYTORCH_OPCHECK_ACCEPT=1"
    )


def validate_failures_dict_structure(
    failure_dict: "FailuresDict", test_utils: List[str], testcase: Any
) -> None:
    """Validates the failures dict.

    The failures dict looks something like the following.
    It maps operator name (qualname) to a dict of autogenerated tests, keyed
    by "{TestCase name}.{autogenerated test name}".
    Each autogenerated test may have a check for the operator (if the operator is
    called by the test); the dictionary specifies if we should skip the check,
    or if we expect some check to fail.

    {
        "fbgemm::split_lengths": {
            "MyTestCase.test_schema__test_split_lengths": {
                "comment": "you can put whatever you want into the comment section",
                "status": "xfail",
            },
            "MyTestCase.test_schema__test_split_lengths_empty": {
                "comment": "",
                "status": "skip",
            },
        },
        "fbgemm::gather_lengths": {
            "MyTestCase.test_schema__test_gather_lengths": {
                "comment": "",
                "status": "skip",
            },
        },
    }

    """
    failure_dict = failure_dict.data
    qualnames = list(failure_dict.keys())
    for test_to_option in failure_dict.values():
        test_names = list(test_to_option.keys())
        for test_name, test_dict in test_to_option.items():
            if set(test_dict.keys()) != {"comment", "status"}:
                raise RuntimeError(
                    "in failures_dict, expected sub-dict to have keys 'comment' and 'status'"
                )
            test_option = test_dict["status"]
            if test_option not in TEST_OPTIONS:
                raise RuntimeError(
                    f"In failures_dict, got status={test_option} but it needs to be in {TEST_OPTIONS}"
                )
            test_class, actual_test_name = test_name.split(".")
            if not any(actual_test_name.startswith(test) for test in test_utils):
                raise RuntimeError(
                    f"In failures_dict, test name '{test_name}' should begin with one of {test_utils}"
                )
            for test in test_utils:
                if not actual_test_name.startswith(test):
                    continue
                base_test_name = actual_test_name[len(test) + 2 :]
                # remove potential pytest parametrization suffix
                base_test_name = re.sub(r"\[.*\]", "", base_test_name)
                if testcase.__name__ != test_class:
                    continue
                if hasattr(testcase, base_test_name):
                    continue
                raise RuntimeError(
                    f"In failures dict, got test name '{test_name}'. We parsed this as "
                    f"running test '{test}' on '{base_test_name}', but "
                    f"{base_test_name} does not exist on the TestCase '{testcase.__name__}'. "
                    f"Maybe you need to change the test name?"
                )


def should_update_failures_dict() -> bool:
    key = "PYTORCH_OPCHECK_ACCEPT"
    return key in os.environ and os.environ[key] == "1"


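# Typical workflow (the test file and -k filter are hypothetical): re-running
# a failing generated test with this env var set records the failure as an
# "xfail" entry in the failures dict; re-running after a fix flips the entry
# back to the default "xsuccess" (i.e. removes it):
#
#     PYTORCH_OPCHECK_ACCEPT=1 python test/test_custom_ops.py -k test_schema__test_foo

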
_is_inside_opcheck_mode = threading.local()
_is_inside_opcheck_mode.value = False


def is_inside_opcheck_mode():
    return _is_inside_opcheck_mode.value


class OpCheckMode(TorchFunctionMode):
    """
    For a given test, OpCheckMode intercepts calls to operators and runs
    test_util(op, args, kwargs) for each intercepted (op, args, kwargs).
    """

    def __init__(
        self,
        namespaces: List[str],
        test_util_name: str,
        test_util: Callable,
        failures_dict: "FailuresDict",
        test_name: str,
        failures_dict_path: str,
    ):
        # We will intercept calls to ops with these namespaces
        self.namespaces = namespaces
        # The test utility function. Its signature should be (op, args, kwargs) -> None.
        # Examples of test utilities are: schema_check, make_fx_check
        self.test_util = test_util
        self.test_util_name = test_util_name
        # The name of the test that is running this OpCheckMode.
        self.test_name = test_name
        # Maps qualname -> test_name -> skip/xfail
        # Tells us if we should skip a test or assert that there is a failure.
        self.failures_dict = failures_dict
        # Location of the failures dict. Makes it so that the error message is better.
        self.failures_dict_path = failures_dict_path

        # OpCheckMode suppresses errors, collects them here, and then raises them on exit.
        # Maps qualname -> List[(Exception, func, maybe args, maybe kwargs)]
        self.seen_ops_to_errors = {}

    def maybe_raise_errors_on_exit(self) -> None:
        # Check expected failures first
        for qualname in self.seen_ops_to_errors.keys():
            option = self.failures_dict.get_status(qualname, self.test_name)
            if len(self.seen_ops_to_errors[qualname]) == 0:
                if should_update_failures_dict():
                    self.failures_dict.set_status(
                        qualname, self.test_name, "xsuccess", comment=""
                    )
                else:
                    if option == "xfail":
                        raise OpCheckError(
                            f"generate_opcheck_tests: Unexpected success for operator "
                            f"{qualname} on test {self.test_name}. This may mean that "
                            f"you have fixed this test failure. Please rerun the test with "
                            f"PYTORCH_OPCHECK_ACCEPT=1 to automatically update the "
                            f"failures dict, or manually remove the "
                            f"expected failure in the failures dict at "
                            f"{self.failures_dict_path}. "
                            f"For more details, see "
                            f"{GDOC}"
                        )
                continue
        failed_ops = []
        for qualname in self.seen_ops_to_errors.keys():
            option = self.failures_dict.get_status(qualname, self.test_name)
            if option != "xsuccess":
                continue
            if len(self.seen_ops_to_errors[qualname]) == 0:
                continue
            failed_ops.append(qualname)
        if not failed_ops:
            return

        if should_update_failures_dict():
            for op in failed_ops:
                self.failures_dict.set_status(op, self.test_name, "xfail")
            return

517        # recording xfails easier.
518        ex, op, args, kwargs = self.seen_ops_to_errors[failed_ops[0]][0]
519        repro_command = generate_repro(
520            self.test_util_name, op, args, kwargs, save_data=should_print_better_repro()
521        )
522        raise OpCheckError(
523            f"Test generated by `generate_opcheck_tests`, {self.test_name}, "
524            f"failed on operators {failed_ops}. This usually means that the "
525            f"operators are not implemented correctly and may lead to silently "
526            f"incorrect behavior. Set PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 for a standalone repro, "
527            f"or please see "
528            f"{GDOC} "
529            f"for more recommendations. "
530            f"To reproduce this problem locally, try to run the following:\n{repro_command}"
531        ) from ex
532
533    def __enter__(self, *args, **kwargs):
534        self.prev_is_opcheck_mode = _is_inside_opcheck_mode.value
535        self.prev_dynamo_disable = os.environ.get("TORCHDYNAMO_DISABLE", "")
536        _is_inside_opcheck_mode.value = True
537        os.environ["TORCHDYNAMO_DISABLE"] = "1"
538        return super().__enter__(*args, **kwargs)
539
540    def __exit__(self, *args, **kwargs):
541        _is_inside_opcheck_mode.value = self.prev_is_opcheck_mode
542        os.environ["TORCHDYNAMO_DISABLE"] = self.prev_dynamo_disable
543        try:
544            self.maybe_raise_errors_on_exit()
545            if should_update_failures_dict():
546                self.failures_dict.save()
547        finally:
548            result = super().__exit__(*args, **kwargs)
549        return result
550
    def run_test_util(self, op, args, kwargs):
        try:
            self.test_util(op, args, kwargs, copy_inputs=False)
        except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
            # We might get here if the input is already a FakeTensor
            # or if we're in a torch.compile block. Just ignore these
            # since we can't handle them and reporting them as failures
            # is too noisy.
            pass

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # Only intercept calls to operators
        if not isinstance(func, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
            return func(*args, **kwargs)
        if (
            torch.jit.is_tracing()
            or torch.jit.is_scripting()
            or torch._dynamo.is_compiling()
        ):
            return func(*args, **kwargs)
        # Pre-existing code may not use the .default overload. If we see an
        # OpOverloadPacket, we attempt to resolve the overload; if we cannot
        # resolve it uniquely, we throw and ask the user to clarify.
        if isinstance(func, torch._ops.OpOverloadPacket):
            func = resolve_unique_overload_or_throw(func)
        qualname = func.name()
        ns = qualname.split("::")[0]
        if ns not in self.namespaces:
            return func(*args, **kwargs)

        args_c, kwargs_c = deepcopy_tensors((args, kwargs))
        result = func(*args, **kwargs)

        option = self.failures_dict.get_status(qualname, self.test_name)
        if option == "xsuccess" or option == "xfail":
            # Suppress all errors during execution. Raise them during __exit__.
            try:
                if qualname not in self.seen_ops_to_errors:
                    self.seen_ops_to_errors[qualname] = []
                self.run_test_util(func, args_c, kwargs_c)
            except Exception as ex:
                if should_print_better_repro():
                    self.seen_ops_to_errors[qualname].append((ex, func, args, kwargs))
                else:
                    self.seen_ops_to_errors[qualname].append((ex, func, None, None))
        elif option == "skip":
            pass
        return result


def should_print_better_repro() -> bool:
    """If set, the tests generated by `generate_opcheck_tests` will print a
    repro command on failure.

    In order to print the repro command, we need to save some tensors to disk.
    These will be saved under the following directory:
    {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/.

    Although this is a temp folder, it will usually not automatically get cleaned
    up, so you'll need to manually delete it.
    """
    key = "PYTORCH_OPCHECK_PRINT_BETTER_REPRO"
    if key not in os.environ:
        return False
    value = os.environ[key]
    # os.environ values are always strings, so comparing against "1" suffices.
    return value == "1"


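# Typical invocation (the test file and -k filter are hypothetical):
#
#     PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 python test/test_custom_ops.py -k test_faketensor__test_foo
#
# On failure, the failing (args, kwargs) are saved under
# {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/ and a standalone
# repro script is printed (see ``generate_repro`` below).

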
def opcheck(
    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
    raise_exception: bool = True,
) -> Dict[str, str]:
    """See torch.library.opcheck for docstring"""

    if kwargs is None:
        kwargs = {}
    if isinstance(op, CustomOpDef):
        op = op._opoverload
    if isinstance(op, torch._ops.OpOverloadPacket):
        op = resolve_unique_overload_or_throw(op)
    if not isinstance(op, torch._ops.OpOverload):
        raise ValueError(
            f"opcheck(op, ...): op must be instance of torch._ops.OpOverload, "
            f"e.g. torch.ops.aten.sin.default, got {type(op)}"
        )
    if test_utils == "ALL":
        test_utils = tuple(ALL_TEST_UTILS.keys())
    if isinstance(test_utils, str):
        test_utils = (test_utils,)
    if not isinstance(test_utils, (tuple, list)) or not set(test_utils).issubset(
        ALL_TEST_UTILS.keys()
    ):
        raise ValueError(
            f"opcheck(op, ..., test_utils={test_utils}), expected test_utils "
            f"to be subset of {tuple(ALL_TEST_UTILS.keys())} but it was not"
        )

    results_dict = {}
    for test_util in test_utils:
        tester = ALL_TEST_UTILS[test_util]
        try:
            tester(op, args, kwargs)
            results_dict[test_util] = "SUCCESS"
        except Exception as ex:
            if raise_exception:
                raise OpCheckError(
                    f"opcheck(op, ...): {test_util} failed with {ex} "
                    f"(scroll up for stack trace)"
                ) from ex
            results_dict[test_util] = ex
    return results_dict


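# Usage sketch (illustrative; the operator is hypothetical and assumed to be
# registered already). Each requested test util maps to "SUCCESS" if it
# passes; with raise_exception=False, failures map to the raised exception:
#
#     op = torch.ops.mylib.foo.default
#     results = opcheck(op, (torch.randn(3),), test_utils="test_schema")
#     assert results == {"test_schema": "SUCCESS"}

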
class OpCheckError(Exception):
    pass


def generate_repro(
    test: str,
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    *,
    save_data: bool,
    dry_run: bool = False,
) -> str:
    if save_data:
        now = datetime.datetime.now()
        path = os.path.join(tempfile.gettempdir(), "pytorch_opcheck_safe_to_delete")
        unix_timestamp = datetime.datetime.timestamp(now) * 100000
        filepath = os.path.join(path, f"repro_{unix_timestamp}.pt")
        if not dry_run:
            os.makedirs(path, exist_ok=True)
            torch.save((args, kwargs), filepath)
        args_kwargs = f'args, kwargs = torch.load("{filepath}")'
    else:
        args_kwargs = (
            "# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1\n"
            "# we will fill in the same (args, kwargs) as in your test\n"
            "args = ()  # args to the operator\n"
            "kwargs = {}  # kwargs to the operator"
        )

    ns, name = op._schema.name.split("::")
    overload = op._overloadname

    repro_command = (
        f"# =========================================================\n"
        f"# BEGIN REPRO SCRIPT\n"
        f"# =========================================================\n"
        f"import torch\n"
        f"from torch.testing._internal.optests import opcheck\n"
        f"\n"
        f"# Make sure you have loaded the library that contains the op\n"
        f"# via an import or torch.ops.load_library(...)\n"
        f"op = torch.ops.{ns}.{name}.{overload}\n"
        f"\n"
        f"{args_kwargs}\n"
        f'opcheck(op, args, kwargs, test_utils="{test}")\n'
        f"# =========================================================\n"
        f"# END REPRO SCRIPT\n"
        f"# =========================================================\n"
    )
    return repro_command


def resolve_unique_overload_or_throw(
    op: torch._ops.OpOverloadPacket,
) -> torch._ops.OpOverload:
    all_schemas = torch._C._jit_get_schemas_for_operator(op._qualified_op_name)
    if len(all_schemas) != 1:
        raise RuntimeError(
            f"opcheck can only test operators without overloads. "
            f"Got the following overloads for {op._qualified_op_name}: "
            f"{[schema.overload_name for schema in all_schemas]}"
        )

    overload_name = all_schemas[0].overload_name
    if overload_name == "":
        return op.default
    return getattr(op, overload_name)


DUMP_OPTIONS = {"indent": 2, "sort_keys": True}


FailuresDictData = Dict[str, Dict[str, Dict[str, str]]]


VERSION = 1
DESCRIPTION = (
    f"This is a dict containing failures for tests autogenerated by "
    f"generate_opcheck_tests. "
    f"For more details, please see {GDOC}"
)


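# On disk, the failures dict is a JSON object that wraps the per-operator data
# in a "data" key (see ``_save`` and ``load`` below). A sketch of the file
# contents (the op and test names are hypothetical):
#
#     {
#       "_description": "This is a dict containing failures for tests ...",
#       "_version": 1,
#       "data": {
#         "mylib::foo": {
#           "MyTestCase.test_schema__test_foo": {
#             "comment": "reason this is expected to fail",
#             "status": "xfail"
#           }
#         }
#       }
#     }

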
class FailuresDict:
    def __init__(self, path: str, data: FailuresDictData):
        self.path = path
        self.data = data

    @staticmethod
    def load(path, *, create_file=False) -> "FailuresDict":
        if create_file and not os.path.exists(path):
            result = FailuresDict(path, {})
            result.save()
            return result
        with open(path) as fp:
            contents = fp.read()
            if contents.strip() == "":
                dct = {
                    "_description": DESCRIPTION,
                    "data": {},
                    "_version": VERSION,
                }
            else:
                dct = json.loads(contents)
                assert "data" in dct
                assert "_version" in dct and dct["_version"] == VERSION
        return FailuresDict(path, dct["data"])

    def _save(self, to_str=False) -> Optional[str]:
        to_dump = {
            "_description": DESCRIPTION,
            "data": self.data,
            "_version": VERSION,
        }
        # json.dumps doesn't end with a newline. Let's add one because files
        # should end in newlines.
        serialized = json.dumps(to_dump, **DUMP_OPTIONS) + "\n"
        if to_str:
            return serialized
        with open(self.path, "w") as fp:
            fp.write(serialized)
        return None

    def save(self) -> None:
        return self._save()

    def get_status(self, qualname: str, test_name: str) -> str:
        if qualname not in self.data:
            return "xsuccess"
        dct = self.data[qualname]
        if test_name not in dct:
            return "xsuccess"
        return dct[test_name]["status"]

    def set_status(
        self,
        qualname: str,
        test_name: str,
        status: str,
        *,
        comment: Optional[str] = None,
    ):
        if qualname not in self.data:
            self.data[qualname] = {}
        dct = self.data[qualname]
        if test_name not in dct:
            dct[test_name] = {"status": None, "comment": ""}

        if status == "xsuccess":
            # The default status is "xsuccess".
            del dct[test_name]
        else:
            dct[test_name]["status"] = status
            if comment is not None:
                dct[test_name]["comment"] = comment
826