# Owner(s): ["module: inductor"]
import logging
import os
import re
import shutil
import sys
import tempfile
import unittest
from pathlib import Path

import torch
from torch._inductor import config, test_operators
from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


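# Import test_torchinductor as part of the test package when possible,
# falling back to a top-level import when this file is run directly.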
try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


def filesize(filename: Path):
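    """Return the size of filename in bytes, asserting that it exists."""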
    assert filename.exists(), f"{filename} is missing"
    return os.stat(filename).st_size


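# With trace.enabled, Inductor writes per-compilation debug artifacts
# (FX graphs, pre- and post-fusion IR, generated output code) to a debug
# directory and logs its location as a warning.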
@config.patch("trace.enabled", True)
class TestDebugTrace(test_torchinductor.TestCase):
    def test_debug_trace(self):
        @torch.compile
        def fn(a, b):
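            # realize forces the intermediate into its own buffer, so the
            # two pointwise adds appear as separate scheduler nodes (op0,
            # op1) in the IR dumps checked below.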
            a = test_operators.realize(a + 1) + 2
            return torch.matmul(a, b)

        # TODO(aakhundov): make this work with fresh_inductor_cache
        # instead of force_disable_caches. currently, with fresh_inductor_cache,
        # we get `inductor [('fxgraph_cache_hit', 1)]` in the counters:
        # so the cache is actually hit and the test fails.
        with config.patch(
            {
                "trace.debug_dir": tempfile.mkdtemp(),
                "force_disable_caches": True,
            }
        ):
            with self.assertLogs(
                logging.getLogger("torch._inductor.debug"), level=logging.WARNING
            ) as cm:
                fn(torch.randn(16, 16), torch.randn(16, 16))

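        # Exactly one warning of the form "... debug trace: <dir>" should
        # have been emitted; extract the dump directory from it.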
        self.assertEqual(len(cm.output), 1)
        m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
        self.assertTrue(m)
        filename = Path(m.group(1))
        self.assertTrue(filename.is_dir())
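        # The dumped artifacts should all exist and be non-trivially sized.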
        self.assertGreater(filesize(filename / "fx_graph_readable.py"), 512)
        self.assertGreater(filesize(filename / "fx_graph_runnable.py"), 512)
        self.assertGreater(filesize(filename / "fx_graph_transformed.py"), 512)
        self.assertGreater(filesize(filename / "output_code.py"), 1024)
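        # Pre-fusion IR: the two pointwise adds (op0, op1) are separate
        # SchedulerNodes; the matmul (op2) is an extern kernel.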
        self.assertExpectedInline(
            open(filename / "ir_pre_fusion.txt").read().rstrip(),
            """\
op0: SchedulerNode(ComputedBuffer)
op0.writes = [MemoryDep('buf0', c0, {c0: 256}, None)]
op0.unmet_dependencies = []
op0.met_dependencies = [MemoryDep('arg0_1', c0, {c0: 256}, None)]
op0.outputs = [
    buf0: ComputedBuffer
    buf0.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf0.users = [NodeUser(node=SchedulerNode(name='op1'), can_inplace=True, is_weak=False)]
]
op0.group.device = cpu
op0.group.iteration = ((256,), ())
op0.sizes = ([256], [])
arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
class op0_loop_body:
    var_ranges = {z0: 256}
    index0 = z0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('arg0_1', get_index)
        constant = ops.constant(1.0, torch.float32)
        add = ops.add(load, constant)
        get_index_1 = self.get_index('index0')
        store = ops.store('buf0', get_index_1, add, None)
        return store


op1: SchedulerNode(ComputedBuffer)
op1.writes = [MemoryDep('buf1', c0, {c0: 256}, None)]
op1.unmet_dependencies = [MemoryDep('buf0', c0, {c0: 256}, None)]
op1.met_dependencies = []
op1.outputs = [
    buf1: ComputedBuffer
    buf1.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf1.users = [NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False)]
]
op1.group.device = cpu
op1.group.iteration = ((256,), ())
op1.sizes = ([256], [])
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
class op1_loop_body:
    var_ranges = {z0: 256}
    index0 = z0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index)
        constant = ops.constant(2.0, torch.float32)
        add = ops.add(load, constant)
        get_index_1 = self.get_index('index0')
        store = ops.store('buf1', get_index_1, add, None)
        return store


op2: ExternKernelSchedulerNode(ExternKernelOut)
op2.writes = [StarDep(name='buf2', mode=None)]
op2.unmet_dependencies = [StarDep(name='buf1', mode=None)]
op2.met_dependencies = [StarDep(name='arg1_1', mode=None)]
op2.outputs = [
    buf2: ExternKernelOut
    buf2.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf2.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
]
op2.node.kernel = extern_kernels.mm""",
        )
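        # Post-fusion IR: op0 and op1 are fused into a single
        # FusedSchedulerNode; the extern matmul kernel (op2) is unchanged.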
        self.assertExpectedInline(
            open(filename / "ir_post_fusion.txt").read().rstrip(),
            """\
op0_op1: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op0_op1.writes = [MemoryDep('buf0', c0, {c0: 256}, None), MemoryDep('buf1', c0, {c0: 256}, None)]
op0_op1.unmet_dependencies = []
op0_op1.met_dependencies = [MemoryDep('arg0_1', c0, {c0: 256}, None)]
op0_op1.outputs = [
    buf0: ComputedBuffer
    buf0.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf0.users = [NodeUser(node=SchedulerNode(name='op1'), can_inplace=True, is_weak=False)]
    buf1: ComputedBuffer
    buf1.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf1.users = [NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False)]
]
op0_op1.snodes[0] =
op0: SchedulerNode(ComputedBuffer)
op0.writes = [MemoryDep('buf0', c0, {c0: 256}, None)]
op0.unmet_dependencies = []
op0.met_dependencies = [MemoryDep('arg0_1', c0, {c0: 256}, None)]
op0.outputs = [
    buf0: ComputedBuffer
    buf0.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf0.users = [NodeUser(node=SchedulerNode(name='op1'), can_inplace=True, is_weak=False)]
]
op0.group.device = cpu
op0.group.iteration = ((256,), ())
op0.sizes = ([256], [])
arg0_1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
class op0_loop_body:
    var_ranges = {z0: 256}
    index0 = z0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('arg0_1', get_index)
        constant = ops.constant(1.0, torch.float32)
        add = ops.add(load, constant)
        get_index_1 = self.get_index('index0')
        store = ops.store('buf0', get_index_1, add, None)
        return store
op0_op1.snodes[1] =
op1: SchedulerNode(ComputedBuffer)
op1.writes = [MemoryDep('buf1', c0, {c0: 256}, None)]
op1.unmet_dependencies = [MemoryDep('buf0', c0, {c0: 256}, None)]
op1.met_dependencies = []
op1.outputs = [
    buf1: ComputedBuffer
    buf1.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf1.users = [NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False)]
]
op1.group.device = cpu
op1.group.iteration = ((256,), ())
op1.sizes = ([256], [])
buf0_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
buf1_layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
class op1_loop_body:
    var_ranges = {z0: 256}
    index0 = z0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index)
        constant = ops.constant(2.0, torch.float32)
        add = ops.add(load, constant)
        get_index_1 = self.get_index('index0')
        store = ops.store('buf1', get_index_1, add, None)
        return store


op2: ExternKernelSchedulerNode(ExternKernelOut)
op2.writes = [StarDep(name='buf2', mode=None)]
op2.unmet_dependencies = [StarDep(name='buf1', mode=None)]
op2.met_dependencies = [StarDep(name='arg1_1', mode=None)]
op2.outputs = [
    buf2: ExternKernelOut
    buf2.layout = FixedLayout('cpu', torch.float32, size=[16, 16], stride=[16, 1])
    buf2.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
]
op2.node.kernel = extern_kernels.mm""",
        )
        # intentionally clean up only on success, so a failing test is easier to debug
        shutil.rmtree(filename)

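    # Smoke test for debug tracing under max-autotune, where multiple
    # kernel template choices may be benchmarked; it only checks that
    # compilation and execution complete without raising.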
    @unittest.skipIf(not HAS_GPU, "requires GPU")
    def test_debug_multi_template(self):
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.Linear(100, 100)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.l(x))

        # smoke test: should complete without raising
        with self.assertLogs(
            logging.getLogger("torch._inductor.debug"), level=logging.WARNING
        ), fresh_inductor_cache():
            m = ToyModel().to(device=GPU_TYPE)
            m = torch.compile(m, mode="max-autotune")
            input_tensor = torch.randn(100).to(device=GPU_TYPE)
            m(input_tensor)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    if HAS_CPU:
        run_tests(needs="filelock")