xref: /aosp_15_r20/external/pytorch/torch/_dynamo/test_minifier_common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import dataclasses
3import io
4import logging
5import os
6import re
7import shutil
8import subprocess
9import sys
10import tempfile
11import traceback
12from typing import Optional
13from unittest.mock import patch
14
15import torch
16import torch._dynamo
17import torch._dynamo.test_case
18from torch._dynamo.trace_rules import _as_posix_path
19from torch.utils._traceback import report_compile_source_on_error
20
21
@dataclasses.dataclass
class MinifierTestResult:
    # Source text of the two scripts a minifier test run yields: the minifier
    # launcher script and the final repro script.
    minifier_code: str
    repro_code: str

    # Pull the `class Repro(torch.nn.Module):` definition out of the source
    # text `t` and normalize its whitespace so two extractions can be compared
    # as plain strings.  Fails an assertion if `t` contains no such class.
    def _get_module(self, t):
        found = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t)
        assert found is not None, "failed to find module"
        # Replace each run of whitespace that reaches a line end with a single
        # newline (this drops trailing blanks on a line).
        without_trailing_ws = re.sub(
            r"\s+$", "\n", found.group(0), flags=re.MULTILINE
        )
        # Never keep more than one consecutive blank line.
        squeezed = re.sub(r"\n{3,}", "\n\n", without_trailing_ws)
        return squeezed.strip()

    # Normalized `Repro` module found in the minifier launcher script.
    def minifier_module(self):
        return self._get_module(self.minifier_code)

    # Normalized `Repro` module found in the repro script.
    def repro_module(self):
        return self._get_module(self.repro_code)
40
41
class MinifierTestBase(torch._dynamo.test_case.TestCase):
    # Base class for minifier tests.  It provides helpers that run a piece of
    # failing code, then the generated `minifier_launcher.py`, then the
    # generated `repro.py`, either in real subprocesses or in-process.

    # Scratch directory used as dynamo's `debug_dir_root` and as the working
    # directory of the test code.  It is created when this class body is
    # evaluated and removed in `tearDownClass` (unless PYTORCH_KEEP_TMPDIR=1).
    DEBUG_DIR = tempfile.mkdtemp()

    # Install the config patches on the class-level exit stack; they are
    # undone when that stack is closed in `tearDownClass`.
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(  # type: ignore[attr-defined]
            torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR)
        )
        # These configurations make new process startup slower.  Disable them
        # for the minification tests to speed them up.
        cls._exit_stack.enter_context(  # type: ignore[attr-defined]
            torch._inductor.config.patch(
                {
                    # https://github.com/pytorch/pytorch/issues/100376
                    "pattern_matcher": False,
                    # multiprocess compilation takes a long time to warmup
                    "compile_threads": 1,
                    # https://github.com/pytorch/pytorch/issues/100378
                    "cpp.vec_isa_ok": False,
                }
            )
        )

    # Remove the scratch directory (or report where it was kept when
    # PYTORCH_KEEP_TMPDIR=1) and undo the config patches.
    @classmethod
    def tearDownClass(cls):
        try:
            if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1":
                shutil.rmtree(cls.DEBUG_DIR)
            else:
                print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
        finally:
            # Close the exit stack even if removing the directory failed, so
            # the config patches never outlive the test class.
            cls._exit_stack.close()  # type: ignore[attr-defined]

    # Returns config code that, in addition to the current dynamo and inductor
    # config, turns on the injected relu bug of kind `bug_type` for the
    # backend matching `device` ("cpu" -> cpp, anything else -> triton).
    def _gen_codegen_fn_patch_code(self, device, bug_type):
        assert bug_type in ("compile_error", "runtime_error", "accuracy")
        return f"""\
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r}
"""

    # Run the command line `args` (which must start with "python3").
    # With isolate=True this is a real subprocess.  With isolate=False the
    # program text is exec'd in this process instead, and a
    # subprocess.CompletedProcess is synthesized: returncode is 0, or 1 if the
    # code raised an Exception; stdout is always empty; stderr holds the
    # "torch._dynamo" log output plus the traceback, if any.
    # `cwd`, when given, is the working directory for the run.
    def _maybe_subprocess_run(self, args, *, isolate, cwd=None):
        if isolate:
            if cwd is not None:
                cwd = _as_posix_path(cwd)
            return subprocess.run(args, capture_output=True, cwd=cwd, check=False)

        assert len(args) >= 2, args
        assert args[0] == "python3", args
        if args[1] == "-c":
            assert len(args) == 3, args
            code = args[2]
            args = ["-c"]
        else:
            with open(args[1]) as f:
                code = f.read()
            args = args[1:]

        # WARNING: This is not a perfect simulation of running
        # the program out of tree.  We only interpose on things we KNOW we
        # need to handle for tests.  If you need more stuff, you will
        # need to augment this appropriately.

        # NB: Can't use save_config because that will omit some fields,
        # but we must save and reset ALL fields
        dynamo_config = torch._dynamo.config.shallow_copy_dict()
        inductor_config = torch._inductor.config.shallow_copy_dict()
        stderr = io.StringIO()
        # Stays None unless the original working directory was captured, so
        # the cleanup below never touches an unassigned name.
        prev_cwd = None
        try:
            log_handler = logging.StreamHandler(stderr)
            log = logging.getLogger("torch._dynamo")
            log.addHandler(log_handler)
            try:
                if cwd is not None:
                    prev_cwd = _as_posix_path(os.getcwd())
                    os.chdir(_as_posix_path(cwd))
                with patch("sys.argv", args), report_compile_source_on_error():
                    exec(code, {"__name__": "__main__", "__compile_source__": code})
                rc = 0
            except Exception:
                rc = 1
                traceback.print_exc(file=stderr)
            finally:
                log.removeHandler(log_handler)
                if prev_cwd is not None:
                    os.chdir(prev_cwd)
                # Make sure we don't leave buggy compiled frames lying
                # around
                torch._dynamo.reset()
        finally:
            torch._dynamo.config.load_config(dynamo_config)
            torch._inductor.config.load_config(inductor_config)

        # TODO: return a more appropriate data structure here
        return subprocess.CompletedProcess(
            args,
            rc,
            b"",
            stderr.getvalue().encode("utf-8"),
        )

    # Run `code` via `_maybe_subprocess_run` (a separate python process when
    # isolate=True), with DEBUG_DIR as the working directory.
    # Returns the completed process state and the directory containing the
    # minifier launcher script, if `code` outputted it (None otherwise).  The
    # directory is whatever precedes "minifier_launcher.py" in stderr.
    def _run_test_code(self, code, *, isolate):
        proc = self._maybe_subprocess_run(
            ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR
        )
        stderr = proc.stderr.decode("utf-8")

        print("test stdout:", proc.stdout.decode("utf-8"))
        print("test stderr:", stderr)
        repro_dir_match = re.search(r"(\S+)minifier_launcher.py", stderr)
        if repro_dir_match is not None:
            return proc, repro_dir_match.group(1)
        return proc, None

    # Runs the minifier launcher script in `repro_dir` with the "minify"
    # subcommand and any extra `minifier_args`.
    # Returns the completed process state and the launcher's source text.
    # Fails the test if the launcher script is missing or if the minifier
    # reports that the input graph did not fail the tester.
    def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()):
        self.assertIsNotNone(repro_dir)
        launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
        # Check before opening so a missing script is reported as a test
        # failure rather than as an error raised by open().
        self.assertTrue(
            os.path.exists(launch_file), f"{launch_file} does not exist"
        )
        with open(launch_file) as f:
            launch_code = f.read()

        args = ["python3", launch_file, "minify", *minifier_args]
        if not isolate:
            args.append("--no-isolate")
        launch_proc = self._maybe_subprocess_run(args, isolate=isolate, cwd=repro_dir)
        print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
        stderr = launch_proc.stderr.decode("utf-8")
        print("minifier stderr:", stderr)
        self.assertNotIn("Input graph did not fail the tester", stderr)

        return launch_proc, launch_code

    # Runs the repro script in `repro_dir`.
    # Returns the completed process state and the repro script's source text.
    # Fails the test if the repro script is missing.
    def _run_repro(self, repro_dir, *, isolate=True):
        self.assertIsNotNone(repro_dir)
        repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
        # Check before opening so a missing script is reported as a test
        # failure rather than as an error raised by open().
        self.assertTrue(os.path.exists(repro_file), f"{repro_file} does not exist")
        with open(repro_file) as f:
            repro_code = f.read()

        repro_proc = self._maybe_subprocess_run(
            ["python3", repro_file], isolate=isolate, cwd=repro_dir
        )
        print("repro stdout:", repro_proc.stdout.decode("utf-8"))
        print("repro stderr:", repro_proc.stderr.decode("utf-8"))
        return repro_proc, repro_code

    # Template for testing code.
    # `run_code` is the code to run for the test case; it is placed after a
    # preamble that imports torch, replays the current dynamo and inductor
    # config, and sets `repro_after`, `repro_level` and `debug_dir_root`
    # (to DEBUG_DIR) on the dynamo config.
    def _gen_test_code(self, run_code, repro_after, repro_level):
        return f"""\
import torch
import torch._dynamo
{_as_posix_path(torch._dynamo.config.codegen_config())}
{_as_posix_path(torch._inductor.config.codegen_config())}
torch._dynamo.config.repro_after = "{repro_after}"
torch._dynamo.config.repro_level = {repro_level}
torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
{run_code}
"""

    # Runs a full minifier test.
    # Minifier tests generally consist of 3 stages:
    # 1. Run the problematic code
    # 2. Run the generated minifier launcher script
    # 3. Run the generated repro script
    #
    # If possible, you should run the test with isolate=False; use
    # isolate=True only if the bug you're testing would otherwise
    # crash the process
    #
    # `expected_error` is text that must appear in the stderr of stages 1 and
    # 3; pass None to assert instead that stage 1 succeeds and produces no
    # minifier launcher, in which case None is returned.  Otherwise returns a
    # MinifierTestResult holding the launcher and repro script sources.
    def _run_full_test(
        self, run_code, repro_after, expected_error, *, isolate, minifier_args=()
    ) -> Optional[MinifierTestResult]:
        if isolate:
            repro_level = 3
        elif expected_error is None or expected_error == "AccuracyError":
            repro_level = 4
        else:
            repro_level = 2
        test_code = self._gen_test_code(run_code, repro_after, repro_level)
        print("running test", file=sys.stderr)
        test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate)
        if expected_error is None:
            # Just check that there was no error
            self.assertEqual(test_proc.returncode, 0)
            self.assertIsNone(repro_dir)
            return None
        # NB: Intentionally do not test return code; we only care about
        # actually generating the repro, we don't have to crash
        self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
        self.assertIsNotNone(repro_dir)
        print("running minifier", file=sys.stderr)
        _minifier_proc, minifier_code = self._run_minifier_launcher(
            repro_dir, isolate=isolate, minifier_args=minifier_args
        )
        print("running repro", file=sys.stderr)
        repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
        self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
        self.assertNotEqual(repro_proc.returncode, 0)
        return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)
250