# Owner(s): ["oncall: profiler"]

import functools
import os
import re
import textwrap
import traceback
import unittest

import expecttest

import torch
from torch._C._profiler import _ExtraFields_PyCall, _ExtraFields_PyCCall
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TEST_WITH_CROSSREF,
    TestCase,
)
from torch.utils._pytree import tree_map


# These functions can vary based on platform and build (e.g. with CUDA) and
# generally distract from rather than add to the test.
PRUNE_ALL = 1
KEEP_ELLIPSES = 2
KEEP_NAME_AND_ELLIPSES = 3

PRUNE_FUNCTIONS = {
    "torch/utils/_pytree.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
    "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
    "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
    "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
    "<built-in method __exit__ of torch._C.DisableTorchFunctionSubclass object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
    "cudaStreamIsCapturing": PRUNE_ALL,
    # These show up only on CUDA; prune them so the CUDA and CPU expected results can be the same
    "cudaGetDeviceCount": PRUNE_ALL,
    "cudaGetDeviceProperties_v2": PRUNE_ALL,
}

# ROCTracer is currently not producing events that the profiler can extract. We
# should bring it up to parity with the CUPTI Kineto / profiler integration, but
# in the meantime there is still utility in running the tests without checking
# that the values match the expected values.
#  1) We will still catch runtime errors and assert failures
#  2) We can diff the output to see how far we are from parity
#
# TODO: We also fail to capture events for Windows on some platforms.
ALLOW_CUDA_FAILURE = (torch.version.hip is not None) or IS_WINDOWS


class TorchFunctionTensor(torch.Tensor):
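    """Tensor subclass used to exercise the `__torch_function__` path in traces."""
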
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs)


class TorchDispatchTensor(torch.Tensor):
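    """Wrapper subclass that unwraps and rewraps tensors around each op via `__torch_dispatch__`."""
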
    @staticmethod
    def __new__(cls, elem):
        t = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        t.elem = elem
        return t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem if isinstance(x, TorchDispatchTensor) else x

        def wrap(x):
            return TorchDispatchTensor(x) if isinstance(x, torch.Tensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})

        return tree_map(wrap, func(*args, **kwargs))


class ProfilerTree:
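    """Helpers for checking profiler traces against inline expected trees.

    A sketch of the pattern used by the tests below: decorate the test with
    `@ProfilerTree.test`, run the workload under `torch.profiler.profile()`
    as `p`, and then compare with::

        self.assertTreesMatch(ProfilerTree.format(p.profiler, 12), expected)

    where `expected` is an inline string literal maintained with
    `EXPECTTEST_ACCEPT=1`.
    """
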
    @staticmethod
    def test(f):
        """Mark a unit test that uses ProfilerTree to check traces.

        This decorator serves two purposes. First, it provides a method name
        that `format` can use to tell where the test runner (which is
        environment-specific) ends and the unit test begins. Second, it runs
        the test with replicates and allows `assertTreesMatch` to adjust
        based on which replicate is running.
        """

        @functools.wraps(f)
        def begin_unit_test_marker(self, replicates=3):
            try:
                for i in range(replicates):
                    self.tree_replicate = i
                    out = f(self)
                    if self.tree_replicate is None:
                        break
                return out
            finally:
                delattr(self, "tree_replicate")

        return begin_unit_test_marker

    @classmethod
    def format(cls, profiler, indent: int = 0):
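        """Render the profiler's event tree as an indented, newline-joined string.

        Frames listed in PRUNE_FUNCTIONS are pruned or elided, trailing CUDA/HIP
        synchronization events are dropped, and everything above the
        `begin_unit_test_marker` frame (the environment-specific test runner)
        is stripped so the output is stable across platforms.
        """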
        def flatten(nodes, depth=0, out=None):
            if out is None:
                out = []

            for node in nodes:
                cls.validate_node(node)
                name = cls.fmt_name(node.name)
                prune_level = PRUNE_FUNCTIONS.get(name.strip(), None)
                if prune_level is None:
                    out.append((depth, name))
                    flatten(node.children, depth + 1, out)
                elif prune_level == KEEP_NAME_AND_ELLIPSES:
                    out.append((depth, name))
                    if node.children:
                        out.append((depth + 1, "..."))
                elif prune_level == KEEP_ELLIPSES:
                    out.append((depth, "..."))
                else:
                    assert prune_level == PRUNE_ALL

            return out

        flat_nodes = flatten(profiler.kineto_results.experimental_event_tree())

        # The profiler inserts a `cudaDeviceSynchronize` at the end of profiling,
        # and may also insert a 'Context Sync' CUDA synchronization event.
        if flat_nodes and flat_nodes[-2][1] == "cudaDeviceSynchronize":
            flat_nodes = flat_nodes[:-2]

        if flat_nodes and flat_nodes[-1][1] == "cudaDeviceSynchronize":
            flat_nodes = flat_nodes[:-1]

        # The profiler inserts a `hipDeviceSynchronize` at the end of profiling.
        if flat_nodes and flat_nodes[-1][1] == "hipDeviceSynchronize":
            flat_nodes = flat_nodes[:-1]

        min_depth = min(
            [d + 1 for d, name in flat_nodes if "begin_unit_test_marker" in name] or [0]
        )
        return textwrap.indent(
            "\n".join(
                [
                    f"{'  ' * (d - min_depth)}{name.rstrip()}"
                    for d, name in flat_nodes
                    if d >= min_depth
                ]
            ),
            " " * indent,
        )

    @staticmethod
    def fmt_name(name: str) -> str:
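        """Normalize an event name so it is stable across runs and machines.

        Python frame line numbers are replaced with "...", templated CUDA kernel
        names are collapsed to `<...>(...)`, and pointer addresses are masked as
        0xXXXXXXXXXXXX.
        """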
        match = re.match(r"^(.*)\.py\(([0-9]+)\): (.*)$", name)
        if match:
            filename, _, fn = match.groups()

            # This test can appear as `test/profiler/test_profiler_tree.py`
            # depending on where it is run from.
            test_file = os.path.splitext(os.path.split(__file__)[1])[0]
            if filename.endswith(test_file):
                filename = test_file

            # We test against a string literal, so all paths have to look like POSIX paths.
            filename = filename.replace(os.sep, "/")

            # We don't want to have to update this test every time PyTorch changes.
            # At some point we should test some line numbers, but for now it's
            # too brittle.
            lineno = "..."

            return f"{filename}.py({lineno}): {fn}"

        for kernel_pattern in (
            "void at::native::elementwise_kernel",
            "void at::native::reduce_kernel",
            "void at::native::vectorized_elementwise_kernel",
            "void at::native::unrolled_elementwise_kernel",
            r"void [a-zA-Z0-9]+_kernel",  # Nvidia kernels.
        ):
            name = re.sub(
                rf"{kernel_pattern}<.+>\(.+\)$",
                f"{kernel_pattern.replace('[a-zA-Z0-9]+', '...')}<...>(...)",
                name,
            )

        return re.sub("object at 0x[0-9a-fA-F]+>", "object at 0xXXXXXXXXXXXX>", name)

    @classmethod
    def validate_node(cls, node):
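        """Check that Python-event lineage agrees with the caller recorded by the tracer."""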
        extra_fields = node.extra_fields
        if isinstance(extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)):
            # Check that the lineage established by the profiler matches the
            # caller recorded by the Python tracer.
            parent = node.parent
            while parent is not None:
                if isinstance(parent.extra_fields, _ExtraFields_PyCall):
                    break
                parent = parent.parent

            def to_string(frame_state):
                return f"{frame_state.file_name}(...): {frame_state.function_name}"

            if parent:
                parent_name = to_string(parent.extra_fields.callsite)
                caller_name = to_string(extra_fields.caller)
                assert parent_name == caller_name, f"{parent_name} vs. {caller_name}"


@unittest.skipIf(IS_ARM64, "Not working on ARM")
class TestProfilerTree(TestCase):
    def assertTreesMatch(self, actual: str, expected: str, allow_failure: bool = False):
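        """Compare a formatted profiler tree against the inline expected string."""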
        # Warning: Here be dragons
        #   Different platforms will have subtly different behavior for Python
        #   tracing. Observed differences include:
        #     1) Windows symbolicates names differently from POSIX
        #     2) The profile callback for c_call does not fire for Tensor.__pow__
        #        on certain platforms. This is not caused by the function tracer,
        #        but by CPython itself.
        #
        # The purpose of these unit tests is to ensure that the profiler is
        # doing reasonable things. When these platform-dependent variations occur,
        # simply coerce them into a platform-independent form. If you make a
        # change in the codebase that changes the trace produced, simply use
        # EXPECTTEST_ACCEPT=1 to update the tests to reflect the new structure.

        # expecttest will not show the diff view if `len(actual) < len(expected)`
        if not expecttest.ACCEPT:
            actual = actual.ljust(len(expected))
        self.maxDiff = None

        replicate = getattr(self, "tree_replicate", None)
        self.assertIsNotNone(
            replicate, "Please annotate test with `@ProfilerTree.test`"
        )

        # The profiler should produce deterministic results and should return
        # to a clean state after each run. As a result, only the first
        # replicate is allowed to update `expected`. If subsequent runs do not
        # match, it is a bug in the profiler.
        if replicate:
            self.assertEqual(actual, expected)
        else:
            try:
                self.assertExpectedInline(actual, expected, skip=1)
            except AssertionError as e:
                if allow_failure:
                    self.tree_replicate = None
                    msg = traceback.format_exception_only(type(e), e)[0]
                    print(msg.split("AssertionError:")[-1])
                else:
                    raise

    # TODO: Add logic for CUDA version of test
    @ProfilerTree.test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_experimental_tree(self):
        t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
        with torch.profiler.profile() as p:
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = (y - z) ** 2
            loss.backward()

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            aten::add
            aten::ones
              aten::empty
              aten::fill_
            aten::sub
            aten::pow
              aten::result_type
              aten::to
            aten::ones_like
              aten::empty_like
                aten::empty_strided
              aten::fill_
            autograd::engine::evaluate_function: PowBackward0
              PowBackward0
                aten::pow
                  aten::result_type
                  aten::to
                  aten::copy_
                aten::mul
                  aten::mul
                    aten::to
                      aten::_to_copy
                        aten::empty_strided
                        aten::copy_
                aten::mul
            autograd::engine::evaluate_function: SubBackward0
              SubBackward0
                aten::neg
            autograd::engine::evaluate_function: AddBackward0
              AddBackward0
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::new_empty_strided
                  aten::empty_strided
                aten::copy_
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::detach
                  detach""",
        )

    # TODO: Add logic for CUDA version of test
    @ProfilerTree.test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_experimental_tree_with_record_function(self):
        with torch.profiler.profile() as p:
            with torch.autograd.profiler.record_function("Top level Annotation"):
                with torch.autograd.profiler.record_function("First Annotation"):
                    x = torch.ones((1,), requires_grad=True)

                # Check that we correctly handle the case when a user
                # annotation does not call `__exit__`.
                _ = torch.autograd.profiler.record_function(
                    "Second Annotation"
                ).__enter__()

                y = x + 1
                with torch.autograd.profiler.record_function("Third Annotation"):
                    y.backward()

        # NB: The `aten::zeros` calls before the record function annotations are
        # due to `at::cpp_custom_type_hack`. When we switch to
        # `torch::CustomClassHolder`, they will disappear.
        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            Top level Annotation
              First Annotation
                aten::ones
                  aten::empty
                  aten::fill_
              Second Annotation
                aten::add
                  aten::to
                    aten::_to_copy
                      aten::empty_strided
                      aten::copy_
                Third Annotation
                  aten::ones_like
                    aten::empty_like
                      aten::empty_strided
                    aten::fill_
                  autograd::engine::evaluate_function: AddBackward0
                    AddBackward0
                  autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                    torch::autograd::AccumulateGrad
                      aten::new_empty_strided
                        aten::empty_strided
                      aten::copy_""",
        )

    # TODO: Add logic for CUDA version of test
    @ProfilerTree.test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_experimental_tree_with_memory(self):
        t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
        with torch.profiler.profile(profile_memory=True) as p:
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = (y - z) ** 2
            loss.backward()

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            aten::add
              [memory]
            aten::ones
              aten::empty
                [memory]
              aten::fill_
            aten::sub
              [memory]
            aten::pow
              aten::result_type
              aten::to
              [memory]
            aten::ones_like
              aten::empty_like
                aten::empty_strided
                  [memory]
              aten::fill_
            autograd::engine::evaluate_function: PowBackward0
              PowBackward0
                aten::pow
                  aten::result_type
                  aten::to
                  [memory]
                  aten::copy_
                aten::mul
                  [memory]
                  aten::mul
                    aten::to
                      aten::_to_copy
                        aten::empty_strided
                          [memory]
                        aten::copy_
                    [memory]
                    [memory]
                  [memory]
                aten::mul
                  [memory]
                [memory]
                [memory]
              [memory]
            autograd::engine::evaluate_function: SubBackward0
              SubBackward0
                aten::neg
                  [memory]
              [memory]
            autograd::engine::evaluate_function: AddBackward0
              AddBackward0
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::new_empty_strided
                  aten::empty_strided
                    [memory]
                aten::copy_
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::detach
                  detach
            [memory]""",
        )

    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_memory_and_stack(self):
        t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
        with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = torch.pow(y - z, 2)
            loss.backward()

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_memory_and_stack
              torch/profiler/profiler.py(...): __enter__
                ...
              <built-in method add of type object at 0xXXXXXXXXXXXX>
                aten::add
                  [memory]
              <built-in method ones of type object at 0xXXXXXXXXXXXX>
                aten::ones
                  aten::empty
                    [memory]
                  aten::fill_
              aten::sub
                [memory]
              <built-in method pow of type object at 0xXXXXXXXXXXXX>
                aten::pow
                  aten::result_type
                  aten::to
                  [memory]
              torch/_tensor.py(...): backward
                <built-in function _has_torch_function_unary>
                torch/autograd/__init__.py(...): backward
                  <built-in method _are_functorch_transforms_active of PyCapsule object at 0xXXXXXXXXXXXX>
                  <built-in function isinstance>
                  <built-in function isinstance>
                  <built-in function len>
                  torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
                  torch/autograd/__init__.py(...): _make_grads
                    typing.py(...): inner
                      typing.py(...): __hash__
                        <built-in function hash>
                    typing.py(...): cast
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in method ones_like of type object at 0xXXXXXXXXXXXX>
                      aten::ones_like
                        aten::empty_like
                          aten::empty_strided
                            [memory]
                        aten::fill_
                    <built-in method append of list object at 0xXXXXXXXXXXXX>
                  torch/autograd/graph.py(...): _engine_run_backward
                    logging/__init__.py(...): getEffectiveLevel
                    <built-in method run_backward of torch._C._EngineBase object at 0xXXXXXXXXXXXX>
                      autograd::engine::evaluate_function: PowBackward0
                        PowBackward0
                          aten::pow
                            aten::result_type
                            aten::to
                            [memory]
                            aten::copy_
                          aten::mul
                            [memory]
                            aten::mul
                              aten::to
                                aten::_to_copy
                                  aten::empty_strided
                                    [memory]
                                  aten::copy_
                              [memory]
                              [memory]
                            [memory]
                          aten::mul
                            [memory]
                          [memory]
                          [memory]
                        [memory]
                      autograd::engine::evaluate_function: SubBackward0
                        SubBackward0
                          aten::neg
                            [memory]
                        [memory]
                      autograd::engine::evaluate_function: AddBackward0
                        AddBackward0
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::new_empty_strided
                            aten::empty_strided
                              [memory]
                          aten::copy_
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::detach
                            detach
                [memory]
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )

    @skipIfTorchDynamo("too slow")
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_stack_and_modules(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = [
                    torch.nn.ReLU(),
                    torch.nn.Linear(1, 1),
                    torch.nn.ReLU(),
                ]

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for l in self.layers:
                    x = l(x)
                return x

        model = MyModule()
        with torch.profiler.profile(with_stack=True) as p:
            for _ in range(2):
                model(torch.ones((1,)))
        self.maxDiff = None
        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_stack_and_modules
              torch/profiler/profiler.py(...): __enter__
                ...
              <built-in method ones of type object at 0xXXXXXXXXXXXX>
                aten::ones
                  aten::empty
                  aten::fill_
              nn.Module: MyModule_0
                torch/nn/modules/module.py(...): _call_impl
                  <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                  test_profiler_tree.py(...): forward
                    nn.Module: ReLU_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
                    nn.Module: Linear_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/linear.py(...): forward
                          torch/nn/modules/module.py(...): __getattr__
                          torch/nn/modules/module.py(...): __getattr__
                          <built-in function linear>
                            aten::linear
                              aten::reshape
                                aten::view
                              aten::t
                                aten::transpose
                                  aten::as_strided
                              aten::addmm
                                aten::expand
                                  aten::as_strided
                                aten::copy_
                                aten::resolve_conj
                                aten::resolve_conj
                                aten::resolve_conj
                              aten::view
                    nn.Module: ReLU_1
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
              <built-in method ones of type object at 0xXXXXXXXXXXXX>
                aten::ones
                  aten::empty
                  aten::fill_
              nn.Module: MyModule_0
                torch/nn/modules/module.py(...): _call_impl
                  <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                  test_profiler_tree.py(...): forward
                    nn.Module: ReLU_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
                    nn.Module: Linear_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/linear.py(...): forward
                          torch/nn/modules/module.py(...): __getattr__
                          torch/nn/modules/module.py(...): __getattr__
                          <built-in function linear>
                            aten::linear
                              aten::reshape
                                aten::view
                              aten::t
                                aten::transpose
                                  aten::as_strided
                              aten::addmm
                                aten::expand
                                  aten::as_strided
                                aten::copy_
                                aten::resolve_conj
                                aten::resolve_conj
                                aten::resolve_conj
                              aten::view
                    nn.Module: ReLU_1
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )

    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_stack_and_torch_function(self):
        x = TorchFunctionTensor(torch.ones((1,)))
        y = torch.ones((1,))

        # There's some lazy initialization in __torch_function__. If we don't
        # run this, the first run won't match the replicates.
        torch.add(x, y)

        with torch.profiler.profile(with_stack=True) as p:
            torch.add(x, y)

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_stack_and_torch_function
              torch/profiler/profiler.py(...): __enter__
                ...
              <built-in method add of type object at 0xXXXXXXXXXXXX>
                test_profiler_tree.py(...): __torch_function__
                  torch/_tensor.py(...): __torch_function__
                    <built-in function all>
                      torch/_tensor.py(...): <genexpr>
                        <built-in function issubclass>
                      torch/_tensor.py(...): <genexpr>
                    <built-in method add of type object at 0xXXXXXXXXXXXX>
                      aten::add
                    torch/_tensor.py(...): _convert
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in method as_subclass of Tensor object at 0xXXXXXXXXXXXX>
                        aten::alias
                      <built-in function isinstance>
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )

    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_stack_and_torch_dispatch(self):
        x = TorchDispatchTensor(torch.ones((1,)))
        y = torch.ones((1,))

        # warmup round
        with torch.profiler.profile(with_stack=True):
            x + y

        with torch.profiler.profile(with_stack=True) as p:
            x + y

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_stack_and_torch_dispatch
              torch/profiler/profiler.py(...): __enter__
                ...
              aten::add
                torch/_library/simple_registry.py(...): find_torch_dispatch_rule
                  torch/_library/simple_registry.py(...): find
                  torch/_library/simple_registry.py(...): find
                    <built-in method get of dict object at 0xXXXXXXXXXXXX>
                test_profiler_tree.py(...): __torch_dispatch__
                  torch/utils/_pytree.py(...): tree_map
                    ...
                  torch/utils/_pytree.py(...): tree_map
                    ...
                  torch/_ops.py(...): __call__
                    <built-in method  of PyCapsule object at 0xXXXXXXXXXXXX>
                      aten::add
                  torch/utils/_pytree.py(...): tree_map
                    ...
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )

    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @ProfilerTree.test
    def test_profiler_experimental_tree_cuda(self):
        with torch.profiler.profile(profile_memory=True) as p:
            weight = torch.ones(1, device="cuda", requires_grad=True)
            x = torch.ones(1, device="cuda")
            y = torch.add(weight, x)
            loss = torch.pow(y, 2)
            loss.backward()
            torch.optim.SGD([weight], lr=0.01, momentum=0.9).step()

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            aten::ones
              aten::empty
                [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            aten::ones
              aten::empty
                [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            aten::add
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::pow
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              aten::result_type
              aten::to
              [memory]
            aten::ones_like
              aten::empty_like
                aten::empty_strided
                  [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            autograd::engine::evaluate_function: PowBackward0
              PowBackward0
                aten::pow
                  aten::result_type
                  aten::to
                  [memory]
                  aten::copy_
                    cudaMemcpyAsync
                      Memcpy DtoD (Device -> Device)
                aten::mul
                  [memory]
                  aten::mul
                    cudaLaunchKernel
                      void at::native::vectorized_elementwise_kernel<...>(...)
                    [memory]
                  [memory]
                aten::mul
                  cudaLaunchKernel
                    void at::native::vectorized_elementwise_kernel<...>(...)
                  [memory]
                [memory]
                [memory]
            autograd::engine::evaluate_function: AddBackward0
              AddBackward0
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::detach
                  detach
            [memory]
            aten::zeros
              aten::zeros
                aten::empty
                  [memory]
                aten::zero_
            Optimizer.step#SGD.step
              aten::empty
                [memory]
              [memory]
              [memory]
              aten::clone
                aten::empty_strided
                  [memory]
                aten::copy_
                  cudaMemcpyAsync
                    Memcpy DtoD (Device -> Device)
              aten::detach
                detach
              aten::add_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            [memory]""",  # noqa: B950
            allow_failure=ALLOW_CUDA_FAILURE,
        )

    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @ProfilerTree.test
    def test_profiler_experimental_tree_cuda_with_stream(self):
        streams = [torch.cuda.Stream() for _ in range(3)]
        results = []
        with torch.profiler.profile(profile_memory=True) as p:
            x = torch.ones((4, 4), device="cuda")
            for stream in streams:
                with torch.cuda.stream(stream):
                    results.append(torch.tanh(x) - x)
        del results
        for s in streams:
            torch.cuda.current_stream().wait_stream(s)

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            aten::ones
              aten::empty
                [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            aten::tanh
              cudaMalloc
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::sub
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            [memory]
            aten::tanh
              cudaMalloc
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::sub
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            [memory]
            aten::tanh
              cudaMalloc
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::sub
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            [memory]""",
            allow_failure=ALLOW_CUDA_FAILURE,
        )

    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @ProfilerTree.test
    def test_profiler_experimental_tree_cuda_detailed(self):
        # Do lazy imports ahead of time so they don't show up in the tree
        import torch.nested._internal.nested_tensor

        model = torch.nn.modules.Linear(1, 1, device="cuda")
        model.train()
        opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

        def step():
            x = torch.ones((1, 1), device="cuda")
            loss = model(x)
            loss.backward()
            opt.step()

        # Warmup
        for _ in range(3):
            step()

        with torch.profiler.profile(profile_memory=True, with_stack=True) as p:
            step()

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_cuda_detailed
              torch/profiler/profiler.py(...): __enter__
                ...
              test_profiler_tree.py(...): step
                <built-in method ones of type object at 0xXXXXXXXXXXXX>
                  aten::ones
                    aten::empty
                      [memory]
                    aten::fill_
                      cudaLaunchKernel
                        void at::native::vectorized_elementwise_kernel<...>(...)
                nn.Module: Linear_0
                  <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                  torch/nn/modules/linear.py(...): forward
                    torch/nn/modules/module.py(...): __getattr__
                    torch/nn/modules/module.py(...): __getattr__
                    <built-in function linear>
                      aten::linear
                        aten::t
                          aten::transpose
                            aten::as_strided
                        aten::addmm
                          cudaMemcpyAsync
                            Memcpy DtoD (Device -> Device)
                          cudaLaunchKernel
                            void ..._kernel<...>(...)
                          [memory]
                          aten::expand
                            aten::as_strided
                torch/_tensor.py(...): backward
                  <built-in function _has_torch_function_unary>
                  torch/autograd/__init__.py(...): backward
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function len>
                    torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
                    torch/autograd/__init__.py(...): _make_grads
                      typing.py(...): inner
                        typing.py(...): __hash__
                          <built-in function hash>
                      typing.py(...): cast
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in method ones_like of type object at 0xXXXXXXXXXXXX>
                        aten::ones_like
                          aten::empty_like
                            aten::empty_strided
                              [memory]
                          aten::fill_
                            cudaLaunchKernel
                              void at::native::vectorized_elementwise_kernel<...>(...)
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                    <built-in method run_backward of torch._C._EngineBase object at 0xXXXXXXXXXXXX>
                      autograd::engine::evaluate_function: AddmmBackward0
                        AddmmBackward0
                          aten::t
                            aten::transpose
                              aten::as_strided
                          aten::mm
                            cudaLaunchKernel
                              void ..._kernel<...>(...)
                            [memory]
                          aten::t
                            aten::transpose
                              aten::as_strided
                        aten::sum
                          aten::sum
                            cudaLaunchKernel
                              void at::native::reduce_kernel<...>(...)
                            [memory]
                        aten::view
                          aten::view
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::add_
                            cudaLaunchKernel
                              void at::native::vectorized_elementwise_kernel<...>(...)
                          [memory]
                      autograd::engine::evaluate_function: TBackward0
                        TBackward0
                          aten::t
                            aten::transpose
                              aten::as_strided
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::add_
                            cudaLaunchKernel
                              void at::native::vectorized_elementwise_kernel<...>(...)
                          [memory]
                  [memory]
                torch/optim/optimizer.py(...): wrapper
                  <built-in method format of str object at 0xXXXXXXXXXXXX>
                  torch/autograd/profiler.py(...): __init__
                    <built-in method zeros of type object at 0xXXXXXXXXXXXX>
                      aten::zeros
                        aten::zeros
                          aten::empty
                            [memory]
                          aten::zero_
                  torch/autograd/profiler.py(...): __enter__
                    torch/_ops.py(...): __call__
                      <built-in method _record_function_enter of PyCapsule object at 0xXXXXXXXXXXXX>
                        Optimizer.step#SGD.step
                          aten::empty
                            [memory]
                          [memory]
                    [memory]
                  torch/optim/optimizer.py(...): _use_grad
                    <built-in function is_grad_enabled>
                    torch/autograd/grad_mode.py(...): __init__
                      <built-in function is_grad_enabled>
                      <built-in function _set_grad_enabled>
                    torch/optim/sgd.py(...): step
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      torch/optim/sgd.py(...): sgd
                        torch/optim/sgd.py(...): _single_tensor_sgd
                          <built-in method mul_ of Tensor object at 0xXXXXXXXXXXXX>
                            [memory]
                            aten::mul_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                            [memory]
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                          <built-in method mul_ of Tensor object at 0xXXXXXXXXXXXX>
                            [memory]
                            aten::mul_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                            [memory]
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                    torch/autograd/grad_mode.py(...): __init__
                      <built-in function is_grad_enabled>
                      <built-in function _set_grad_enabled>
                  torch/autograd/profiler.py(...): __exit__
                    torch/_ops.py(...): __call__
                      <built-in method _record_function_exit of PyCapsule object at 0xXXXXXXXXXXXX>
              [memory]
              [memory]
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  torch/profiler/profiler.py(...): _transit_action
                    <built-in method get of dict object at 0xXXXXXXXXXXXX>
                      enum.py(...): __hash__
                        <built-in function hash>
                    ...""",  # noqa: B950
            allow_failure=ALLOW_CUDA_FAILURE,
        )


if __name__ == "__main__":
    run_tests()