# Owner(s): ["oncall: profiler"]

# If tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is turning off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    pass

import json
import sys
import tempfile
import unittest
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch import _dynamo as torchdynamo
from torch.autograd import (
    _record_function_with_args_enter,
    _record_function_with_args_exit,
)
from torch.profiler import (
    ExecutionTraceObserver,
    kineto_available,
    profile,
    record_function,
    supported_activities,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._triton import has_triton


Json = Dict[str, Any]


class TestExecutionTrace(TestCase):
    def payload(self, use_cuda=False):
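        """Run a small mix of ops wrapped in record_function annotations.

        Exercises user annotations with scalar, list, tuple, string, tensor,
        and inf/nan arguments, a forward/backward pass, and an optional
        CPU<->CUDA round trip when use_cuda is True.
        """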
        u = torch.randn(3, 4, 5, requires_grad=True)
        with record_function("## TEST 1 ##", "1, 2, 3"):
            inf_val = float("inf")
            neg_inf_val = float("-inf")
            nan_val = float("nan")
            rf_handle = _record_function_with_args_enter(
                "## TEST 2 ##",
                1,
                False,
                2.5,
                [u, u],
                (u, u),
                "hello",
                u,
                inf_val,
                neg_inf_val,
                nan_val,
            )
            x = torch.randn(10, 10, requires_grad=True)
            if use_cuda:
                x = x.cuda()
            y = torch.randn(10, 10, requires_grad=True)
            if use_cuda:
                y = y.cuda()
            z = x + y + x * y + x * y
            z.backward(z)
            gelu = nn.GELU()
            m = torch.randn(2)
            _ = gelu(m)
            if use_cuda:
                z = z.cpu()
            _record_function_with_args_exit(rf_handle)

    def get_execution_trace_root(self, output_file_name) -> Json:
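        """Load the execution trace JSON file and return its "nodes" list."""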
        nodes = []
        with open(output_file_name) as f:
            et_graph = json.load(f)
            assert "nodes" in et_graph
            nodes = et_graph["nodes"]
        return nodes

    def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
        """Return a sorted list of rf_ids (record function ids) in the execution trace."""

        def get_rf_id(node):
            attrs = node["attrs"]
            for a in attrs:
                if a["name"] == "rf_id":
                    return a["value"]
            return None

        rf_ids_ = (
            get_rf_id(n)
            for n in nodes
            if n["name"] != "[pytorch|profiler|execution_trace|process]"
            and n["name"] != "[pytorch|profiler|execution_trace|thread]"
        )
        return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)

    def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
        """Return a sorted list of record function ids for CPU operators and user annotations."""
        ops_and_annotations = (
            e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]
        )
        return sorted(
            e.get("args", {}).get("Record function id", -1) for e in ops_and_annotations
        )

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_execution_trace_with_kineto(self):
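        """Collect an execution trace and a Kineto trace in the same profiler
        session and check that their record function ids (rf_id) match."""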
        trace_called_num = 0

        def trace_handler(p):
            nonlocal trace_called_num
            trace_called_num += 1

        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace and kineto data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        kt = tempfile.NamedTemporaryFile(
            mode="w+t", suffix=".kineto.json", delete=False
        )
        kt.close()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            on_trace_ready=trace_handler,
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    self.payload(use_cuda=use_cuda)
                p.step()
            self.assertEqual(fp.name, p.execution_trace_observer.get_output_file_path())

        # Uncomment for debugging
        # print("Output kineto = ", kt.name)
        # print("Output ET = ", fp.name)

        p.export_chrome_trace(kt.name)
        self.assertEqual(trace_called_num, 1)

        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        self.assertTrue(found_root_node)
        # The profiler trace is active for 2 iterations, so expect 2 loop events.
        self.assertEqual(loop_count, 2)

        # Compare the collected execution trace and Kineto trace in terms of
        # record function ids (rf_id) and external ids; both of these should
        # match for the same trace window.

        with open(kt.name) as f:
            kineto = json.load(f)
            events = kineto["traceEvents"]

        # Look up rf_ids in both the execution trace and the Kineto trace as two lists.
        rf_ids_et = self.get_execution_trace_rf_ids(nodes)
        rf_ids_kineto = self.get_kineto_rf_ids(events)

        self.assertCountEqual(rf_ids_et, rf_ids_kineto)
        self.assertListEqual(
            rf_ids_et,
            rf_ids_kineto,
            msg=f"ET and kineto rf_id should exactly match\n"
            f"  rf_ids_et = {rf_ids_et}\n"
            f"  rf_ids_kineto = {rf_ids_kineto}\n",
        )

    def test_execution_trace_alone(self):
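        """Collect an execution trace with a standalone observer (no profiler
        schedule) and validate the root node, loop count, and tensor tuple encoding."""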
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        expected_loop_events = 0

        et = ExecutionTraceObserver().register_callback(fp.name)

        et.start()
        for idx in range(5):
            expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)
        et.stop()

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        # Expected tensor object tuple size, in the form of:
        # [tensor_id, storage_id, offset, numel, itemsize, device_str]
        tensor_tuple_size = 6
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
            # Check if the tensor tuple representation size is correct.
            if n["name"] == "## TEST 2 ##":
                assert len(n["inputs"]["values"][3][0]) == tensor_tuple_size
        assert found_root_node
        assert loop_count == expected_loop_events

    @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(not TEST_CUDA or not has_triton(), "need CUDA and triton to run")
    def test_execution_trace_with_pt2(self):
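        """Profile a torch.compile (inductor) region and check that generated
        triton kernels are captured in the execution trace with a kernel_file attribute."""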
        @torchdynamo.optimize("inductor")
        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x + c
            return x.cos()

        a, b, c = (torch.randn(4, 4, requires_grad=True).to("cuda") for _ in range(3))

        inputs = [a, b, c]
        with torch._inductor.config.patch(compile_threads=1):
            fn(*inputs)

        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False)
        fp.close()

        with profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    fn(*inputs)
                p.step()

        nodes = self.get_execution_trace_root(fp.name)
        found_captured_triton_kernel_node = False
        for n in nodes:
            assert "name" in n
            if "triton_" in n["name"]:
                for attr in n["attrs"]:
                    if attr["name"] == "kernel_file" and attr["value"] != "":
                        found_captured_triton_kernel_node = True
                        assert len(n["inputs"]["values"]) > 0
                        assert len(n["outputs"]["values"]) == 0
        assert found_captured_triton_kernel_node

    def test_execution_trace_start_stop(self):
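        """Start and stop the observer multiple times within a loop and check
        that only iterations with an active observer are recorded."""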
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        expected_loop_events = 0
        et = ExecutionTraceObserver().register_callback(fp.name)
        for idx in range(10):
            if idx == 3:
                et.start()
            elif idx == 5:
                et.stop()
            elif idx == 8:
                et.start()
            elif idx == 9:
                et.stop()
            if et._execution_trace_running:
                expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        assert found_root_node
        assert loop_count == expected_loop_events

    def test_execution_trace_repeat_in_loop(self):
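        """Register a fresh observer on selected iterations and verify each
        output file has a root node and that the total number of loop events
        matches the number of capture iterations."""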
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        iter_list = {3, 4, 6, 8}
        expected_loop_events = len(iter_list)
        output_files = []
        for idx in range(10):
            if idx in iter_list:
                # Create a temp file to save execution trace data.
                fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
                fp.close()
                output_files.append(fp.name)
                et = ExecutionTraceObserver().register_callback(fp.name)
                et.start()
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)
            if idx in iter_list:
                et.stop()
                et.unregister_callback()

        event_count = 0
        for et_file in output_files:
            nodes = self.get_execution_trace_root(et_file)
            found_root_node = False
            for n in nodes:
                assert "name" in n
                if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                    assert n["id"] == 1
                    found_root_node = True
                if n["name"].startswith("## LOOP "):
                    event_count += 1
            assert found_root_node
        assert event_count == expected_loop_events

    def test_execution_trace_no_capture(self):
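        """Register and unregister the observer without capturing anything and
        check that the output still contains the process root node."""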
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        et = ExecutionTraceObserver().register_callback(fp.name)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
        assert found_root_node

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
    def test_execution_trace_nested_tensor(self):
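        """Profile ops on nested (jagged) tensors and check that the cos op
        shows up in the execution trace."""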
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()

        observer = ExecutionTraceObserver().register_callback(fp.name)

        def fn(nt):
            return nt.sin().cos()

        with torch.profiler.profile(execution_trace_observer=observer) as prof:
            for i in range(3):
                values = torch.rand((8 + i, 4 + i))
                offsets = torch.tensor([0, 2, 4, 6, 8 + i])
                nt = torch.nested.nested_tensor_from_jagged(values, offsets)
                fn(nt)

        nodes = self.get_execution_trace_root(fp.name)
        found_cos = False
        for n in nodes:
            assert "name" in n
            if "cos" in n["name"]:
                found_cos = True
        assert found_cos


if __name__ == "__main__":
    run_tests()