# Owner(s): ["oncall: profiler"]

# If tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test against their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is to turn off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    pass

import gc
import re
import textwrap
import unittest
import weakref
from typing import Any, Dict, List

import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
from torch._C._profiler import _TensorMetadata
from torch.profiler import _utils, profile
from torch.testing._internal.common_utils import run_tests, TestCase


Json = Dict[str, Any]

from torch._C._profiler import _ExtraFields_PyCall


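# Helpers that walk the profiler's experimental event tree depth-first and
# return the first node whose name matches exactly (find_node_with_name) or by
# regex (find_node_with_regex), or None if nothing matches. A minimal usage
# sketch, mirroring the pattern used throughout the tests below:
#
#     with profile(with_stack=True) as p:
#         _ = torch.ones((1,)) + torch.ones((1,))
#     roots = p.profiler.kineto_results.experimental_event_tree()
#     add_node = find_node_with_name(roots, "aten::add")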
def find_node_with_name(nodes, name):
    for node in _utils.traverse_dfs(nodes):
        if node.name == name:
            return node


def find_node_with_regex(nodes, pattern):
    for node in _utils.traverse_dfs(nodes):
        if re.search(pattern, node.name):
            return node

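# Minimal two-layer network used by the nn.Module and optimizer tests below.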

class SimpleNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))


class TestTorchTidyProfiler(TestCase):
    def _get_tensor_fields(self, node, index):
        self.assertIsNotNone(node)
        self.assertIsInstance(
            node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
        )
        tensor_info = node.extra_fields.inputs[index]
        self.assertIsInstance(tensor_info, _TensorMetadata)
        self.assertIsNotNone(tensor_info.impl_ptr)
        self.assertIsNotNone(tensor_info.storage_data_ptr)
        self.assertIsNotNone(tensor_info.id)
        return tensor_info.impl_ptr, tensor_info.storage_data_ptr, tensor_info.id

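    # Verifies that the profiler records three distinct notions of tensor
    # identity (TensorImpl pointer, storage data pointer, and profiler-assigned
    # ID) and that views, resize_(), and set_() affect each as expected.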
    def test_pointers_and_ids(self):
        a = torch.randn(4, 3)
        a_initial_storage_data = a.storage().data_ptr()

        # Views of tensors can share the same storage, but have different TensorImpls
        b = a.view((1, 12))
        c = torch.randn(4, 1)
        c_initial_storage_data = c.storage().data_ptr()
        d = torch.randn(4, 3)

        with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
            _ = a + c
            _ = b * c

            # Resize should create a new data_ptr but keep the TensorImpl the same.
            f = a.resize_(128, 129)
            _ = torch.relu(f)

            # `.set_` points a Tensor at an existing storage.
            _ = d.sin()
            c.set_(d.storage())
            _ = c.cos()

        nodes = p.profiler.kineto_results.experimental_event_tree()

        def get_fields(op_name, index):
            return self._get_tensor_fields(find_node_with_name(nodes, op_name), index)

        a_impl, a_storage_data, a_id = get_fields("aten::add", 0)
        b_impl, b_storage_data, b_id = get_fields("aten::mul", 0)

        # Profiler matches ground truth from Python API.
        self.assertEqual(a_storage_data, a_initial_storage_data)

        # Views are handled correctly.
        self.assertEqual(a_storage_data, b_storage_data)
        self.assertNotEqual(a_impl, b_impl)

        # The same Tensor used in multiple calls gives identical results.
        c_impl, c_storage_data, c_id = get_fields("aten::add", 1)
        self.assertEqual((c_impl, c_storage_data, c_id), get_fields("aten::mul", 1))
        self.assertEqual(c_storage_data, c_initial_storage_data)

        # Mutations to the underlying storage are reflected. (But ID is shared.)
        f_impl, f_storage_data, f_id = get_fields("aten::relu", 0)
        self.assertEqual(a_impl, f_impl)
        self.assertNotEqual(a_storage_data, f_storage_data)
        self.assertEqual(a_id, f_id)

        # Calling `set_` with an existing Tensor makes them share an ID.
        d_impl, d_storage_data, d_id = get_fields("aten::sin", 0)
        c_impl_new, c_storage_data_new, c_id_new = get_fields("aten::cos", 0)
        self.assertNotEqual(d_impl, c_impl_new)
        self.assertEqual(d_storage_data, c_storage_data_new)
        self.assertEqual(c_id, c_id_new)
        self.assertEqual(d_id, c_id_new)

    @staticmethod
    def _format_allocations(profiled_code):
        gc.collect()
        with profile(profile_memory=True, record_shapes=True) as prof:
            profiled_code()
            gc.collect()

        root_events = prof.profiler.kineto_results.experimental_event_tree()
        events = sorted(_utils.traverse_dfs(root_events), key=lambda x: x.start_time_ns)
        allocations = tuple(
            event.extra_fields
            for event in events
            if isinstance(
                event.extra_fields, torch._C._profiler._ExtraFields_Allocation
            )
        )

        return textwrap.indent(
            "\n".join(
                f"{repr(i.id):>5}{' ' * 6}"
                f"{repr(i.allocation_id):>5}{' ' * 6}"
                f"{'Allocation' if i.alloc_size > 0 else 'Free'}"
                for i in allocations
            ),
            " " * 12,
        )

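    # Each line emitted by `_format_allocations` has three columns: the
    # profiler-assigned Tensor ID, the allocation ID, and whether the event
    # was an Allocation or a Free, in recorded order. The expected-inline
    # blocks below are read against that format.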
    def test_tensorimpl_invalidation_set(self) -> None:
        def profiled_code(add_empty_set: bool):
            x = torch.ones((1,))

            # Determines if new storage is created before or after the old one
            # is destroyed.
            if add_empty_set:
                x.set_()

            x.set_(torch.ones((1,)).storage())
            x.view_as(x)

        self.assertExpectedInline(
            self._format_allocations(lambda: profiled_code(add_empty_set=False)),
            """\
                0          1      Allocation
                0          2      Allocation
                0          1      Free
                0          2      Free""",
        )

        self.assertExpectedInline(
            self._format_allocations(lambda: profiled_code(add_empty_set=True)),
            """\
                0          1      Allocation
                0          1      Free
                0          2      Allocation
                0          2      Free""",
        )

    def test_tensorimpl_invalidation_keep_alive(self) -> None:
        def profiled_code(add_empty_set: bool):
            x = torch.ones((1,))
            x_storages = [x.storage()]
            for _ in range(3):
                x.set_()
                x.set_(torch.ones((1,)).storage())

                # This keeps the StorageImpls alive and preserves the chain.
                # (Despite the `set_()` call.)
                x_storages.append(x.storage())
            x.view_as(x)

            # Free storage in a deterministic fashion.
            while x_storages:
                x_storages.pop()
                gc.collect()

            # Determines if new storage is created before or after the old one
            # is destroyed.
            if add_empty_set:
                x.set_()

            for _ in range(3):
                x.set_(torch.ones((1,)).storage())
            x.view_as(x)

            del x
            gc.collect()

        self.assertExpectedInline(
            self._format_allocations(lambda: profiled_code(add_empty_set=False)),
            """\
                0          1      Allocation
                0          2      Allocation
                0          4      Allocation
                0          5      Allocation
                0          4      Free
                0          2      Free
                0          1      Free
                0          6      Allocation
                0          5      Free
                0          7      Allocation
                0          6      Free
                0          8      Allocation
                0          7      Free
                0          8      Free""",
        )

        self.assertExpectedInline(
            self._format_allocations(lambda: profiled_code(add_empty_set=True)),
            """\
                0          1      Allocation
                0          2      Allocation
                0          4      Allocation
                0          5      Allocation
                0          4      Free
                0          2      Free
                0          1      Free
                0          5      Free
                0          6      Allocation
                0          7      Allocation
                0          6      Free
                0          8      Allocation
                0          7      Free
                0          8      Free""",
        )

    def test_tensorimpl_invalidation_full(self) -> None:
        def profiled_code():
            x = torch.ones((1,))
            x_storages = [x.storage()]
            for _ in range(3):
                x.set_()
                x.set_(torch.ones((1,)).storage())
                x_storages.append(x.storage())
            x.view_as(x)

            # Free storage in a deterministic fashion.
            while x_storages:
                x_storages.pop()
                gc.collect()

            for _ in range(3):
                x.set_(torch.ones((1,)).storage())

            for _ in range(3):
                x.set_()
                x.set_(torch.ones((1,)).storage())

            for i in range(4):
                x.resize_((1 + i,))
            x.view_as(x)

        self.assertExpectedInline(
            self._format_allocations(profiled_code),
            """\
                0          1      Allocation
                0          2      Allocation
                0          4      Allocation
                0          5      Allocation
                0          4      Free
                0          2      Free
                0          1      Free
                0          6      Allocation
                0          5      Free
                0          7      Allocation
                0          6      Free
                0          8      Allocation
                0          7      Free
                0          8      Free
                0          9      Allocation
                0          9      Free
                0         10      Allocation
                0         10      Free
                0         11      Allocation
                0         12      Allocation
                0         11      Free
                0         13      Allocation
                0         12      Free
                0         14      Allocation
                0         13      Free
                0         14      Free""",
        )

    def test_tensorimpl_invalidation_scalar_args(self) -> None:
        def profiled_code():
            with torch.no_grad():
                x = torch.ones((1,))
                for _ in range(10):
                    x.add_(2)

        self.assertExpectedInline(
            self._format_allocations(profiled_code),
            """\
                0          1      Allocation
                1          2      Allocation
                2          3      Allocation
                2          3      Free
                1          2      Free
                3          4      Allocation
                4          5      Allocation
                4          5      Free
                3          4      Free
                5          6      Allocation
                6          7      Allocation
                6          7      Free
                5          6      Free
                7          8      Allocation
                8          9      Allocation
                8          9      Free
                7          8      Free
                9         10      Allocation
               10         11      Allocation
               10         11      Free
                9         10      Free
               11         12      Allocation
               12         13      Allocation
               12         13      Free
               11         12      Free
               13         14      Allocation
               14         15      Allocation
               14         15      Free
               13         14      Free
               15         16      Allocation
               16         17      Allocation
               16         17      Free
               15         16      Free
               17         18      Allocation
               18         19      Allocation
               18         19      Free
               17         18      Free
               19         20      Allocation
               20         21      Allocation
               20         21      Free
               19         20      Free
                0          1      Free""",
        )

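    # Checks that profiler-assigned Tensor IDs line up across op inputs,
    # nn.Module parameters recorded on the module's PyCall event, and optimizer
    # state recorded on the `_optimizer_step_code` event, for both the first
    # (lazy initialization) step and steady state.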
    def test_module_and_optimizer_ids(self) -> None:
        model = torch.nn.Linear(2, 1, bias=True)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        def check(cold_start: bool) -> None:
            with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
                x = torch.ones((1, 2))
                _ = x.sin()  # Mark `x`
                model(x).backward()
                optimizer.step()
                _ = optimizer.state[model.weight][
                    "momentum_buffer"
                ].cos()  # Mark weight momentum
                _ = model.weight.grad.tan()  # Mark weight gradient

            nodes = p.profiler.kineto_results.experimental_event_tree()

            def get_fields(op_name, index):
                return self._get_tensor_fields(
                    find_node_with_name(nodes, op_name), index
                )

            # Marked Tensors act as ground truth for Python tracer IDs.
            _, _, x_id = get_fields("aten::sin", 0)
            _, _, weight_momentum_id = get_fields("aten::cos", 0)
            _, _, weight_grad_id = get_fields("aten::tan", 0)
            self.assertNotEqual(x_id, weight_momentum_id)
            self.assertNotEqual(x_id, weight_grad_id)
            self.assertNotEqual(weight_momentum_id, weight_grad_id)

            # Use the linear op to identify the weight ground truth.
            linear_op_node = find_node_with_name(nodes, "aten::linear")
            self.assertIsNotNone(linear_op_node)
            x_metadata, weight_metadata, _ = linear_op_node.extra_fields.inputs
            self.assertEqual(x_id, x_metadata.id)

            # Module
            linear_module_node = find_node_with_name(nodes, "nn.Module: Linear_0")
            self.assertIsNotNone(linear_module_node)
            self.assertIsNotNone(linear_module_node.extra_fields.module)
            self.assertIsNone(linear_module_node.extra_fields.optimizer)

            linear_parameters = linear_module_node.extra_fields.module.parameters
            name, weight, weight_grad = linear_parameters[0]
            self.assertEqual(name, "weight")
            self.assertEqual(weight.id, weight_metadata.id)

            self.assertEqual(weight_grad is None, cold_start)
            if not cold_start:
                self.assertEqual(weight_grad.id, weight_grad_id)

            # Optimizer
            step_node = find_node_with_regex(nodes, "_optimizer_step_code")
            self.assertIsNotNone(step_node)
            self.assertIsNone(step_node.extra_fields.module)
            self.assertIsNotNone(step_node.extra_fields.optimizer)
            optimizer_parameters = step_node.extra_fields.optimizer.parameters
            self.assertEqual(len(optimizer_parameters), 2)  # Weight and bias
            weight, weight_grad, state = optimizer_parameters[0]
            self.assertEqual(weight.id, weight_metadata.id)
            self.assertEqual(weight_grad.id, weight_grad_id)
            self.assertEqual(len(state), 1)
            self.assertEqual(state[0][0], "momentum_buffer")
            self.assertEqual(state[0][1].id, weight_momentum_id)

        # Check that we handle the first step (lazy initialization) and steady state.
        check(cold_start=True)
        check(cold_start=False)

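    # Shared helper for the allocation ID tests: profiles a resize_() and the
    # eventual deallocation of `x`, with optional extra work before and after
    # to check that ID assignment is robust to surrounding allocations.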
    def _test_allocation_ids(self, before_fn, after_fn) -> None:
        with profile(profile_memory=True, record_shapes=True) as p:
            # Introduce other operations and allocations to check robustness
            _ = before_fn()

            x = torch.rand(4, 3)
            x.resize_(4, 4)

            # Use `x` after the resize so the profiler can determine its ID.
            x.sin()

            # Introduce other operations and allocations to check robustness
            _ = after_fn()

            # Ensure `x` is the last variable collected to make it easier to
            # find the deallocation event.
            gc.collect()
            del x
            gc.collect()

        nodes = p.profiler.kineto_results.experimental_event_tree()

        def find_chain(names: List[str]):
            out = []
            for name in names:
                root = [out[-1]] if out else nodes
                out.append(find_node_with_name(root, name))
                self.assertIsNotNone(out[-1], name)
            return out

        allocation = find_chain(["aten::rand", "aten::empty", "[memory]"])[
            -1
        ].extra_fields
        _, uniform_node = find_chain(["aten::rand", "aten::uniform_"])
        x_impl, x_storage_data, x_id = self._get_tensor_fields(uniform_node, 0)

        # Make sure IDs are consistent between allocations and op inputs
        self.assertEqual(allocation.ptr, x_storage_data)
        self.assertEqual(allocation.id, x_id)

        resize_node = find_node_with_name(nodes, "aten::resize_")
        self.assertIsNotNone(resize_node)
        self.assertEqual(len(resize_node.children), 2)
        allocate_new = resize_node.children[0].extra_fields
        free_old = resize_node.children[1].extra_fields

        # Destruction of the old storage for x.
        self.assertEqual(free_old.id, allocation.id)
        self.assertEqual(free_old.ptr, allocation.ptr)

        # Make sure ID is retained through change in storage.
        self.assertEqual(allocate_new.id, allocation.id)
        self.assertNotEqual(allocate_new.ptr, allocation.ptr)

        # Deletion when `x` goes out of scope.
        free_new = [
            i for i in nodes if i.tag == torch._C._profiler._EventType.Allocation
        ][-1].extra_fields
        self.assertIsInstance(free_new, torch._C._profiler._ExtraFields_Allocation)
        self.assertEqual(free_new.id, allocate_new.id)
        self.assertEqual(free_new.ptr, allocate_new.ptr)

    def test_allocation_ids(self) -> None:
        self._test_allocation_ids(lambda: None, lambda: None)

    def test_allocation_ids_with_other_ops(self) -> None:
        x = torch.ones((1,))
        self._test_allocation_ids(
            lambda: (x + 1).relu_(), lambda: torch.zeros((1,)).cos()
        )

    def test_impl_reuse(self) -> None:
        repeats = 1_000
        with profile(profile_memory=True, record_shapes=True) as p:
            for _ in range(repeats):
                torch.ones((1,))
            gc.collect()

        roots = p.profiler.kineto_results.experimental_event_tree()
        tensor_impls = tuple(
            e.extra_fields.inputs[0].impl_ptr
            for e in _utils.traverse_dfs(roots)
            if e.name == "aten::fill_"
        )

        self.assertEqual(len(tensor_impls), repeats)
        self.assertEqual(len(set(tensor_impls)), repeats)

    def test_allocation_id_uniqueness(self) -> None:
        repeats = 1_000
        with profile(profile_memory=True, record_shapes=True) as p:
            for _ in range(repeats):
                torch.ones((1,))
            gc.collect()

        roots = p.profiler.kineto_results.experimental_event_tree()
        id_set = set()
        for e in _utils.traverse_dfs(roots):
            fields = e.extra_fields
            if isinstance(fields, torch._C._profiler._ExtraFields_TorchOp):
                id_set |= {
                    t.allocation_id
                    for t in fields.inputs
                    if isinstance(t, _TensorMetadata)
                }

            elif isinstance(fields, torch._C._profiler._ExtraFields_Allocation):
                id_set.add(fields.allocation_id)

        id_set.difference_update([None])
        self.assertEqual(repeats, len(id_set))

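    # The concrete type of `node.extra_fields` depends on the event kind:
    # torch ops expose _ExtraFields_TorchOp (with input metadata), Python
    # calls expose _ExtraFields_PyCall / _ExtraFields_PyCCall, and memory
    # events expose _ExtraFields_Allocation, as exercised below.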
    def test_extra_fields(self):
        with profile(with_stack=True, profile_memory=True) as p:
            _ = torch.ones((1,))

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "aten::ones")
        self.assertIsNotNone(node)

        self.assertIsInstance(
            node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
        )

        self.assertIsInstance(
            node.parent.extra_fields, torch._C._profiler._ExtraFields_PyCCall
        )

        self.assertEqual(node.children[0].name, "aten::empty")
        self.assertEqual(node.children[0].children[0].name, "[memory]")
        self.assertIsInstance(
            node.children[0].children[0].extra_fields,
            torch._C._profiler._ExtraFields_Allocation,
        )

    def test_tensor_properties(self):
        x = torch.ones(10, 10).as_strided([4, 4], [12, 3])
        y = torch.ones(4, 1, requires_grad=True)

        with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
            _ = x + y
            _ = x * y

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "aten::add")
        self.assertIsNotNone(node)

        self.assertIsInstance(
            node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
        )

        def getattr_inputs(name, default):
            return [getattr(i, name, default) for i in node.extra_fields.inputs]

        self.assertEqual(getattr_inputs("sizes", []), [[4, 4], [4, 1], []])
        self.assertEqual(getattr_inputs("strides", []), [[12, 3], [1, 1], []])
        self.assertEqual(
            getattr_inputs("layout", None), [torch.strided, torch.strided, None]
        )
        self.assertEqual(
            getattr_inputs("device", None),
            [torch.device("cpu"), torch.device("cpu"), None],
        )
        self.assertEqual(
            getattr_inputs("dtype", None), [torch.float32, torch.float32, None]
        )
        self.assertEqual(node.extra_fields.scope, torch.profiler.RecordScope.FUNCTION)

        mul_node = find_node_with_name(nodes, "aten::mul")
        self.assertIsNotNone(mul_node)
        self.assertEqual(
            node.extra_fields.sequence_number + 1, mul_node.extra_fields.sequence_number
        )

    def test_sparse_tensors(self):
        i = [[0, 1, 1], [2, 0, 2]]
        v = [3, 4, 5]
        s = torch.sparse_coo_tensor(i, v, (2, 3))

        with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
            _ = s + s

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "aten::add")
        self.assertIsNotNone(node)

        self.assertIsInstance(
            node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
        )

        def getattr_inputs(name, default):
            return [getattr(i, name, default) for i in node.extra_fields.inputs]

        self.assertEqual(getattr_inputs("sizes", []), [[2, 3], [2, 3], []])
        self.assertEqual(getattr_inputs("strides", []), [[], [], []])
        self.assertEqual(
            getattr_inputs("layout", None), [torch.sparse_coo, torch.sparse_coo, None]
        )
        self.assertEqual(
            getattr_inputs("device", None),
            [torch.device("cpu"), torch.device("cpu"), None],
        )

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_mkldnn_tensors(self):
        x = torch.ones(4, 3).to_mkldnn()

        with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
            _ = x + x

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "aten::add")
        self.assertIsNotNone(node)

        self.assertIsInstance(
            node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
        )

        def getattr_inputs(name, default):
            return [getattr(i, name, default) for i in node.extra_fields.inputs]

        self.assertEqual(getattr_inputs("sizes", []), [[4, 3], [4, 3], []])
        self.assertEqual(getattr_inputs("strides", []), [[], [], []])
        self.assertEqual(
            getattr_inputs("layout", None), [torch._mkldnn, torch._mkldnn, None]
        )
        self.assertEqual(
            getattr_inputs("device", None),
            [torch.device("cpu"), torch.device("cpu"), None],
        )

    def test_scalar_ins(self):
        x = torch.ones(5, 5)
        alpha = 0.9

        with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
            _ = torch.add(x, 9.1, alpha=alpha)

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "aten::add")
        self.assertIsNotNone(node)

        def getattr_inputs(name, default):
            return [getattr(i, name, default) for i in node.extra_fields.inputs]

        # The second argument to the add gets promoted to a zero-dim Tensor.
        self.assertEqual(
            getattr_inputs("dtype", None), [torch.float32, torch.float64, None]
        )
        self.assertEqual(getattr_inputs("sizes", []), [[5, 5], [], []])
        self.assertEqual(node.extra_fields.inputs[2], alpha)

    def test_tensor_lists(self):
        x = torch.ones((1,))
        y = torch.ones((1,))
        with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
            _ = torch.stack((x, y))

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "aten::stack")
        inputs = node.extra_fields.inputs
        self.assertEqual(len(inputs), 2)
        self.assertIsInstance(inputs[0], list)
        self.assertEqual(len(inputs[0]), 2)
        self.assertEqual(x.storage().data_ptr(), inputs[0][0].storage_data_ptr)
        self.assertEqual(y.storage().data_ptr(), inputs[0][1].storage_data_ptr)

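    # The Python tracer attaches the parameters of each observed nn.Module to
    # its PyCall event; this test checks that both Linear layers of SimpleNet
    # are captured with the expected parameter and gradient storage pointers.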
    def test_nnmodule_params(self):
        def flat_out_extrafields(nodes, out=None):
            if out is None:
                out = []
            for node in nodes:
                if (
                    isinstance(node.extra_fields, _ExtraFields_PyCall)
                    and node.extra_fields.module
                ):
                    if node.extra_fields.module.parameters:
                        out.append(node.extra_fields.module)
                flat_out_extrafields(node.children, out)
            return out

        inputs = torch.rand(10)
        net = SimpleNet()
        out = net(inputs)
        torch.nn.functional.cross_entropy(out, torch.rand(2)).backward()
        with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
            _ = net(inputs)

        modules = flat_out_extrafields(
            p.profiler.kineto_results.experimental_event_tree()
        )
        self.assertEqual(
            len(modules), 2, f"Expected two parameter lists, but got {len(modules)}"
        )

        params = [
            (n, p.storage_data_ptr, g.storage_data_ptr)
            for module in modules
            for (n, p, g) in module.parameters
        ]
        expected = [
            (name, val.storage().data_ptr(), val.grad.storage().data_ptr())
            for name, val in net.fc1._parameters.items()
        ]
        expected += [
            (name, val.storage().data_ptr(), val.grad.storage().data_ptr())
            for name, val in net.fc2._parameters.items()
        ]
        self.assertEqual(expected, params, f"{expected} vs. {params}")

    def _flat_out_extrafields(self, nodes, out=None):
        if out is None:
            out = []
        for node in nodes:
            if (
                isinstance(node.extra_fields, _ExtraFields_PyCall)
                and node.extra_fields.optimizer
                and node.extra_fields.optimizer.parameters
            ):
                # Avoid OptInfo duplicates across iterations.
                addr = node.extra_fields.optimizer.parameters[0][0].storage_data_ptr
                if not [o for o in out if addr == o.parameters[0][0].storage_data_ptr]:
                    out.append(node.extra_fields.optimizer)
            self._flat_out_extrafields(node.children, out)
        return out

    def _check_results(self, opt, opts, check_items=False):
        self.assertEqual(len(opts), 1, f"Expected 1 optimizer, but got {len(opts)}")
        self.assertEqual(
            id(opt),
            opts[0].self_ptr,
            f"Optimizer addr ({id(opt)}) vs. profiled addr ({opts[0].self_ptr})",
        )
        if check_items:
            self.assertEqual(len(opt.param_groups), len(opts))
            for group, opt_ in zip(opt.param_groups, opts):
                self.assertEqual(
                    [(v.storage().data_ptr()) for v in group.get("params", [])],
                    [(o.storage_data_ptr) for (o, _, _) in opt_.parameters],
                )
            for opt_ in opts:
                observed_state = {
                    p.storage_data_ptr: {name: s.storage_data_ptr for name, s in state}
                    for (p, _, state) in opt_.parameters
                }

                # Make sure the profiler collected all optimizer state and check
                # that the address recorded by the profiler is correct.
                for parameter, parameter_state in opt.state.items():
                    self.assertEqual(
                        {
                            name: value.storage().data_ptr()
                            for name, value in parameter_state.items()
                        },
                        observed_state.get(parameter.storage().data_ptr(), []),
                    )

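    # `_flat_out_extrafields` collects the unique optimizer snapshots recorded
    # by the Python tracer, and `_check_results` compares them against the live
    # torch.optim instance (identity, parameter storage pointers, and
    # per-parameter state).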
    def test_optimizer(self):
        inputs = torch.rand(10)
        with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
            net = SimpleNet()
            opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

            opt.zero_grad()
            out = net(inputs)
            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
            loss.backward()
            opt.step()
        self._check_results(
            opt,
            self._flat_out_extrafields(
                p.profiler.kineto_results.experimental_event_tree()
            ),
            False,
        )

    def _test_optimizer_parameters(self, optimizer_factory):
        inputs = torch.rand(10)
        with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
            net = SimpleNet()
            opt = optimizer_factory(net.parameters())
            for _ in range(2):
                opt.zero_grad()
                out = net(inputs)
                loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
                loss.backward()
                opt.step()
        self._check_results(
            opt,
            self._flat_out_extrafields(
                p.profiler.kineto_results.experimental_event_tree()
            ),
            True,
        )

    def test_optimizer_parameters_sgd(self):
        self._test_optimizer_parameters(
            lambda params: torch.optim.SGD(params, lr=0.01, momentum=0.9)
        )

    def test_optimizer_parameters_adam(self):
        self._test_optimizer_parameters(
            lambda params: torch.optim.Adam(params, foreach=True)
        )

    def test_allocations(self):
        gc.collect()
        with profile(profile_memory=True) as p:
            x = torch.empty((3, 4))

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "[memory]")
        self.assertIsNotNone(node)

        alloc_size = 3 * 4 * 4  # fp32 -> 4 bytes
        ptr = node.extra_fields.ptr
        self.assertGreater(ptr, 0)
        self.assertEqual(node.extra_fields.alloc_size, alloc_size)
        self.assertEqual(node.extra_fields.device, torch.device("cpu"))
        total_allocated = node.extra_fields.total_allocated

        # total_reserved is only for CUDACachingAllocator
        self.assertEqual(node.extra_fields.total_reserved, 0)

        with profile(profile_memory=True) as p:
            del x
            gc.collect()

        nodes = p.profiler.kineto_results.experimental_event_tree()
        node = find_node_with_name(nodes, "[memory]")
        self.assertIsNotNone(node)

        self.assertEqual(node.extra_fields.ptr, ptr)
        self.assertEqual(node.extra_fields.alloc_size, -alloc_size)
        self.assertEqual(node.extra_fields.device, torch.device("cpu"))
        self.assertEqual(
            node.extra_fields.total_allocated, total_allocated - alloc_size
        )

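    # with_stack=True records Python frames; this test verifies that the
    # profiler does not keep closed-over objects alive past their normal
    # lifetime once profiling has ended.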
    def test_refcounts(self):
        class Sentinel:
            pass

        def make():
            outer_sentinel = Sentinel()

            def outer():
                # Python will only close over variables used in the function.
                _ = outer_sentinel
                inner_sentinel = Sentinel()

                def inner():
                    _ = inner_sentinel

                with profile(with_stack=True):
                    inner()

                return weakref.ref(inner_sentinel)

            return outer, weakref.ref(outer_sentinel)

        # Use a factory function to ensure the test scope never sees strong
        # references. `del` has strange semantics that interact with closures
        # at an AST level, so this is simpler.
        outer, outer_sentinel_ref = make()
        inner_sentinel_ref = outer()

        self.assertIsNone(inner_sentinel_ref())

        # `outer` holds the last reference via closure.
        self.assertIsNotNone(outer_sentinel_ref())

        del outer
        self.assertIsNone(outer_sentinel_ref())


if __name__ == "__main__":
    run_tests()