xref: /aosp_15_r20/external/pytorch/test/dynamo/test_after_aot.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import io
4import os
5import shutil
6import sys
7import tempfile
8import unittest
9
10import torch._dynamo.test_case
11from torch._dynamo.repro.after_aot import InputReader, InputWriter, save_graph_repro
12from torch.fx.experimental.proxy_tensor import make_fx
13from torch.testing._internal.common_utils import IS_FBCODE
14from torch.utils._traceback import report_compile_source_on_error
15
16
def strip_trailing_whitespace(r):
    """Return *r* with trailing whitespace removed from every line.

    Leading whitespace and the line structure are preserved; only the
    right-hand side of each newline-separated segment is stripped.
    """
    stripped_lines = (line.rstrip() for line in r.split("\n"))
    return "\n".join(stripped_lines)
19
20
class TestAfterAot(torch._dynamo.test_case.TestCase):
    """Tests for the after-AOT repro machinery: generating standalone repro
    scripts with ``save_graph_repro`` and round-tripping tensors through
    ``InputWriter`` / ``InputReader``."""

    @unittest.skipIf(IS_FBCODE, "NotImplementedError")
    def test_save_graph_repro(self):
        """The emitted repro script must be executable both with and without
        the on-disk saved storages directory."""
        # TODO: This triggers CUDA context initialization, even though
        # it is CPU only
        buf = io.StringIO()
        args = [torch.randn(4)]

        def f(x):
            return (x * x,)

        # Trace f into a GraphModule so save_graph_repro has a graph to dump.
        gm = make_fx(f)(*args)
        with tempfile.TemporaryDirectory() as d:
            save_graph_repro(buf, gm, args, "inductor_accuracy", save_dir=d)
            r = buf.getvalue()
            # First run: the saved storages under `d` are still present, so
            # the repro should load inputs from disk.
            with report_compile_source_on_error():
                exec(r, {"__compile_source__": r})

            # Remove the serialized tensor storages the repro refers to.
            shutil.rmtree(os.path.join(d, "storages"))

            # Should still work even without the save dir
            with report_compile_source_on_error():
                exec(r, {"__compile_source__": r})

    @unittest.skipIf(sys.byteorder != "little", "checksum depends on endianness")
    def test_dump_tensor(self):
        """InputWriter with stable_hash=True must emit deterministic,
        content-addressed load scripts that InputReader can execute to
        reconstruct the original tensor."""

        def test(tensor, expected):
            # Write `tensor`, check the generated script text matches
            # `expected` exactly, then exec the script through a reader and
            # verify the reconstructed tensor equals the input.
            with tempfile.TemporaryDirectory() as d:
                writer = InputWriter(d, stable_hash=True)
                writer.tensor("x", tensor)
                self.assertExpectedInline("\n".join(writer._lines), expected, skip=1)
                reader = InputReader(d)
                env = {"reader": reader, "torch": torch}
                # TODO: assert no logs
                exec("\n".join(writer._lines), env)
                self.assertEqual(reader.args[0], tensor)

        # Contiguous float tensor: default dtype, no strides emitted.
        test(
            torch.zeros(3, 4),
            """\
buf0 = reader.storage('c17fd92682ca5b304ac71074b558dda9e8eb4d66', 48)
reader.tensor(buf0, (3, 4), is_leaf=True)  # x""",
        )
        # Non-default dtype: dtype_hint/dtype must appear in the script.
        test(
            torch.ones(3, 4, dtype=torch.int32),
            """\
buf0 = reader.storage('7c221e2da0c58c700cc2996644dd13d042bd552e', 48, dtype_hint=torch.int32)
reader.tensor(buf0, (3, 4), dtype=torch.int32, is_leaf=True)  # x""",
        )
        # Channels-last layout: explicit strides must be recorded.
        test(
            torch.empty((3, 4, 5, 6), memory_format=torch.channels_last).fill_(2),
            """\
buf0 = reader.storage('49ebab3961d6221e64c4c72b0aefd976bdd2afc4', 1440)
reader.tensor(buf0, (3, 4, 5, 6), (120, 1, 24, 4), is_leaf=True)  # x""",
        )
76
77
if __name__ == "__main__":
    # Standard PyTorch dynamo test entry point; run_tests wires up the
    # dynamo-specific test harness.
    from torch._dynamo.test_case import run_tests

    run_tests()
82