1# Owner(s): ["oncall: profiler"]
2import functools
3import gc
4import itertools as it
5import textwrap
6from typing import Callable, Dict, Iterator, List, Optional, Tuple
7
8import torch
9from torch._C._profiler import _EventType, _TensorMetadata
10from torch.profiler import _memory_profiler, _utils
11from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
12from torch.utils import _pytree as pytree
13
14
15profile = functools.partial(
16    torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
17)
18
19
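# A minimal usage sketch of the `profile` partial defined above (illustrative
# only, not exercised by the suite): the three flags it sets are exactly what
# `prof._memory_profile()` requires. The helper name `_example_memory_profile`
# is an assumption for illustration, not part of the test suite's API.
def _example_memory_profile() -> _memory_profiler.MemoryProfile:
    with profile() as prof:
        torch.ones((2, 2)).mul(2)
    return prof._memory_profile()

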
20@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
21class TestMemoryProfiler(TestCase):
22    def test_config_check(self) -> None:
23        with torch.profiler.profile() as prof:
24            pass
25
26        pattern = r"record_shapes=True, profile_memory=True, with_stack=True"
27        with self.assertRaisesRegex(ValueError, pattern):
28            prof._memory_profile()
29
30        with torch.profiler.profile(record_shapes=True, with_stack=True) as prof:
31            pass
32
33        pattern = r"^profile_memory=True required for memory profiling\.$"
34        with self.assertRaisesRegex(ValueError, pattern):
35            prof._memory_profile()
36
37        with profile() as prof:
38            pass
39
40        self.assertIsInstance(prof._memory_profile(), _memory_profiler.MemoryProfile)
41
42
43class ScaleLayer(torch.nn.Module):
44    def __init__(self) -> None:
45        super().__init__()
46        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)
47
48    def forward(self, x: torch.Tensor) -> torch.Tensor:
49        return x * self.scale
50
51
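# Hedged sketch (illustrative only): `ScaleLayer` above multiplies its input by a
# single learnable scalar, so each instance contributes exactly one extra
# parameter (and, after a backward pass, one gradient) for the tests to track.
def _example_scale_layer() -> torch.Tensor:
    layer = ScaleLayer()
    return layer(torch.ones((3,)))  # Shape is preserved; values are scaled.

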
52class LazyLinear(torch.nn.Module):
53    def __init__(self, in_features: int, out_features: int):
54        super().__init__()
55        self.in_features = in_features
56        self.out_features = out_features
57
58    def forward(self, x) -> torch.Tensor:
59        if getattr(self, "weight", None) is None:
60            self.weight = torch.nn.Parameter(
61                torch.empty((self.out_features, self.in_features))
62            )
63            self.bias = torch.nn.Parameter(torch.empty(self.out_features))
64
65        return torch.nn.functional.linear(x, self.weight, self.bias)
66
67
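# Hedged sketch (illustrative only): `LazyLinear` above registers `weight` and
# `bias` on the first forward call rather than in `__init__`, which is what the
# lazy-initialization tests below exercise. The values come from `torch.empty`,
# so only structural properties are checked here.
def _example_lazy_linear_materialization() -> None:
    layer = LazyLinear(2, 2)
    assert not hasattr(layer, "weight")  # Nothing is registered before the first call.
    layer(torch.ones((2, 2)))
    assert isinstance(layer.weight, torch.nn.Parameter)
    assert isinstance(layer.bias, torch.nn.Parameter)

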
68class RecordInputOutputDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
69    def __init__(self) -> None:
70        self.results = []
71
72    def mark_region(self, name: str):
73        self.results.append((name, (), ()))
74
75    @staticmethod
76    def flat_ids(args):
77        flat_args = pytree.tree_leaves(args)
78        return tuple(
79            (t._cdata, t.storage().data_ptr())
80            for t in flat_args
81            if isinstance(t, torch.Tensor) and t.storage()
82        )
83
84    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
85        args = args or []
86        kwargs = kwargs or {}
87        flat_inputs = self.flat_ids(args) + self.flat_ids(kwargs)
88        out = func(*args, **kwargs)
89        flat_outputs = self.flat_ids(out)
90        if (
91            flat_inputs or flat_outputs
92        ) and "_record_function_enter" not in func.name():
93            self.results.append((func.name(), flat_inputs, flat_outputs))
94        return out
95
96
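# Hedged sketch (illustrative only) of how `RecordInputOutputDispatchMode` is used
# further below: while the mode is active, every dispatched op is recorded as a
# (name, input ids, output ids) tuple, where each id is a (TensorImpl pointer,
# storage data pointer) pair and `mark_region` inserts a labeled separator with
# empty id tuples.
def _example_record_dispatch() -> List[str]:
    with RecordInputOutputDispatchMode() as mode:
        mode.mark_region("example region")
        torch.ones((2,)).mul(2)
    return [name for name, _, _ in mode.results]

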
97@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
98class TestIdentifyGradients(TestCase):
99    def gradient_detected(
100        self,
101        prof: torch.profiler.profile,
102        ctx: _EventType,
103        grad_tensor: torch.Tensor,
104        parameter: Optional[torch.Tensor] = None,
105    ) -> bool:
106        # This is not an exhaustive check, but for the purpose of unit testing
107        # it is sufficient.
108        def key_matches_tensor(key, tensor) -> bool:
109            # Vacuous case.
110            if tensor is None:
111                return True
112
113            if key is None:
114                return False
115
116            return tensor.storage().data_ptr() == key.storage.ptr
117
118        tree = prof.profiler.kineto_results.experimental_event_tree()
119        for node in _utils.traverse_dfs(tree):
120            for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
121                if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
122                    if parameter is None:
123                        return True  # Don't need to check parameter; we're done.
124
125                    elif p_key is not None:
126                        # For a complex workflow a gradient could correspond to
127                        # different parameters at different points in a trace.
128                        # However this will not happen in the relatively simple
129                        # cases tested here, so if `extract_gradients` identifies
130                        # the parameter corresponding to a particular gradient it
131                        # must be the one we expect.
132                        self.assertTrue(key_matches_tensor(p_key, parameter))
133                        return True
134
135        return False
136
137    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
138        self.assertTrue(
139            self.gradient_detected(*args, **kwargs),
140            f"Failed to identify gradient `{name}` from profile.",
141        )
142
143    def assertOnlyGradients(
144        self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
145    ) -> None:
146        allowed_set = {t.storage().data_ptr() for t in tensors}
147
148        tree = prof.profiler.kineto_results.experimental_event_tree()
149        for node in _utils.traverse_dfs(tree):
150            for _, p_grad_key in _memory_profiler.extract_gradients(node):
151                self.assertTrue(
152                    p_grad_key.storage.ptr in allowed_set,
153                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
154                )
155
156    def test_extract_gradients_low_level(self) -> None:
157        x = torch.ones((1,))
158        w0 = torch.ones((1,), requires_grad=True)
159        w1 = torch.ones((1,), requires_grad=True)
160
161        def check(cold_start: bool):
162            self.assertEqual(w0.grad is None, cold_start)
163            self.assertEqual(w1.grad is None, cold_start)
164            with profile() as prof:
165                z = x.expand(4) * w0
166                (z * w1).sum().backward()
167
168            # Gradient detection through op inspection does not provide a
169            # reference to the parameter corresponding to the gradient.
170            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
171            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
172            self.assertOnlyGradients(prof, (w0.grad, w1.grad))
173
174        check(cold_start=True)
175        check(cold_start=False)
176
177    def test_extract_gradients_from_module(self) -> None:
178        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
179        named_parameters = dict(model.named_parameters())
180        self.assertEqual(len(named_parameters), 3)
181
182        def assert_only_gradients(prof: torch.profiler.profile):
183            gradients = tuple(i.grad for i in named_parameters.values())
184            self.assertFalse(any(i is None for i in gradients))
185            self.assertOnlyGradients(prof, gradients)
186
187        def check(cold_start: bool):
188            x = torch.ones((2, 2))
189            with profile() as prof:
190                model(x).sum().backward()
191
192            for name, p in named_parameters.items():
193                # The first time we run a module none of the `.grad` fields
194                # have been initialized. This is fine; in that case we can
195                # detect everything we need in the profiled section.
196                self.assertNotEqual(
197                    self.gradient_detected(prof, _EventType.PyCall, p.grad, p),
198                    cold_start,
199                    name,
200                )
201
202                # Op based detection should still identify the gradients.
203                self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
204            assert_only_gradients(prof)
205
206            # We can detect gradients even when `.backward()` is not called.
207            with profile() as prof:
208                model(torch.ones((2, 2)))
209
210            for name, p in named_parameters.items():
211                self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
212                self.assertFalse(
213                    self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
214                )
215            assert_only_gradients(prof)
216
217        check(cold_start=True)
218        check(cold_start=False)
219
220    def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:
221        x = torch.ones((1,))
222        w0 = torch.ones((1,), requires_grad=True)
223        w1 = torch.ones((1,), requires_grad=True)
224        optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)
225
226        def check(cold_start: bool):
227            self.assertEqual(w0.grad is None, cold_start)
228            self.assertEqual(w1.grad is None, cold_start)
229            with profile() as prof:
230                optimizer.zero_grad(set_to_none=set_to_none)
231                z = x.expand(4) * w0
232                (z * w1).sum().backward()
233                optimizer.step()
234
235            # Optimizer instrumentation runs late in the step, so we can detect
236            # gradients for both cold and warm start.
237            self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
238            self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)
239
240            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
241            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
242            self.assertOnlyGradients(prof, (w0.grad, w1.grad))
243
244            with profile() as prof:
245                for _ in range(2):
246                    optimizer.zero_grad(set_to_none=set_to_none)
247                    z = x.expand(4) * w0
248                    (z * w1).sum().backward()
249                    optimizer.step()
250
251            # Inspected state is cached, so if we replace gradients (as is the
252            # case for `set_to_none=True`) our python instrumentation will not
253            # see them.
254            # TODO(robieta): Should `.step()` be excluded from caching?
255            self.assertNotEqual(
256                self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
257                set_to_none,
258            )
259
260            self.assertNotEqual(
261                self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),
262                set_to_none,
263            )
264
265            if set_to_none:
266                with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
267                    self.assertOnlyGradients(prof, (w0.grad, w1.grad))
268
269        check(cold_start=True)
270        check(cold_start=False)
271
272    def test_extract_gradients_from_optimizer(self) -> None:
273        self._test_extract_gradients_from_optimizer(set_to_none=False)
274
275    def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
276        self._test_extract_gradients_from_optimizer(set_to_none=True)
277
278    def test_extract_gradients_from_module_and_optimizer(self) -> None:
279        # Module and optimizer are thoroughly tested individually and should be
280        # additive. Thus we can manage with a lightweight check that they don't
281        # interact adversely.
282        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
283        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
284        with profile() as prof:
285            model(torch.ones((2, 2))).sum().backward()
286            optimizer.step()
287
288        self.assertGradientDetected(
289            "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
290        )
291
292
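# Hedged sketch (illustrative only) of the traversal pattern that
# `TestIdentifyGradients` builds on above: walk the profiled event tree and
# collect the (parameter, gradient) storage pointers reported by
# `_memory_profiler.extract_gradients` for each node. The flat return format is
# an assumption made for illustration, not an API of the profiler.
def _example_collect_gradient_ptrs(
    prof: torch.profiler.profile,
) -> List[Tuple[Optional[int], int]]:
    tree = prof.profiler.kineto_results.experimental_event_tree()
    out: List[Tuple[Optional[int], int]] = []
    for node in _utils.traverse_dfs(tree):
        for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
            out.append((p_key.storage.ptr if p_key else None, p_grad_key.storage.ptr))
    return out

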
293@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
294class TestDataFlow(TestCase):
295    def setUp(self) -> None:
296        super().setUp()
297        self.maxDiff = None
298
299    @staticmethod
300    def formatSchemas(
301        prof: torch.profiler.profile, indent: int = 12
302    ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
303        tree = prof.profiler.kineto_results.experimental_event_tree()
304        out: List[Tuple[str, Tuple[bool, ...]]] = []
305        for node in _utils.traverse_dfs(tree):
306            if node.tag == _EventType.TorchOp:
307                e = node.extra_fields
308                schemas = _memory_profiler.SchemaMatcher.match_schemas(e)
309                name = node.name
310                if len(schemas) == 1:
311                    name = f"{name}.{schemas[0].overload_name}"
312                elif len(schemas) > 1:
313                    name = f"{name}.{{{', '.join(s.overload_name for s in schemas)}}}"
314
315                out.append((name, _memory_profiler.SchemaMatcher.inputs_are_mutable(e)))
316        return tuple(out)
317
318    @staticmethod
319    def _run_and_format_data_flow(
320        inputs: Dict[str, torch.Tensor],
321        f: Callable[..., Optional[Dict[str, torch.Tensor]]],
322        indent: int = 12,
323    ) -> str:
324        with profile() as prof:
325            outputs = f(**inputs) or {}
326            gc.collect()
327
328        memory_profile = prof._memory_profile()
329        graph = memory_profile._data_flow_graph
330        storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}
331
332        lines: List[str] = []
333        for name, t in it.chain(inputs.items(), outputs.items()):
334            lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}")
335            if t.grad is not None:
336                grad_id = storage_to_id[t.grad.storage().data_ptr()]
337                lines.append(f"{name + '.grad:':<9} T{grad_id}")
338
339        if lines:
340            lines.append("")
341
342        for node in graph.flow_nodes:
343            destroyed = {k for k, v in node._edges.items() if v.is_deletion}
344
345            inputs: List[str] = []
346            for key, (_, v) in node.inputs.items():
347                inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")
348
349            outputs = [f"T{key.id}(v{v})" for key, v in node.outputs.items()]
350            if inputs or outputs:
351                event_name = node._event.name.replace("torch::autograd::", "")
352                lines.append(
353                    f"{event_name:<25} {', '.join(inputs):<15}  ->  {', '.join(outputs)}"
354                )
355
356        return textwrap.indent("\n".join([l.rstrip() for l in lines]), " " * indent)
357
358    def test_match_schemas(self) -> None:
359        with profile() as prof:
360            x = torch.ones((1,)).mul(2).add_(2)
361            _ = torch.sin(x, out=torch.empty_like(x))
362
363        self.assertEqual(
364            self.formatSchemas(prof),
365            (
366                ("aten::ones.", (False,) * 5),
367                ("aten::empty.memory_format", (False,) * 6),
368                #
369                # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
370                ("aten::fill_.Scalar", (True, False)),
371                ("aten::mul.Tensor", (False, False)),
372                ("aten::to.dtype", (False,) * 5),
373                ("aten::_to_copy.", (False,) * 7),
374                ("aten::empty_strided.", (False,) * 6),
375                #
376                # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
377                ("aten::copy_.", (True, False, False)),
378                #
379                # add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
380                ("aten::add_.Tensor", (True, False, False)),
381                ("aten::to.dtype", (False,) * 5),
382                ("aten::_to_copy.", (False,) * 7),
383                ("aten::empty_strided.", (False,) * 6),
384                #
385                # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
386                ("aten::copy_.", (True, False, False)),
387                ("aten::empty_like.", (False,) * 6),
388                ("aten::empty_strided.", (False,) * 6),
389                #
390                # sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
391                ("aten::sin.out", (False, True)),
392            ),
393        )
394
395    def test_match_schemas_backward(self) -> None:
396        x = torch.ones((1,))
397        w = torch.ones((1,), requires_grad=True)
398        with profile() as prof:
399            torch.mul(x, w).backward()
400
401        self.assertEqual(
402            self.formatSchemas(prof),
403            (
404                ("aten::mul.Tensor", (False, False)),
405                ("aten::ones_like.", (False,) * 6),
406                ("aten::empty_like.", (False,) * 6),
407                ("aten::empty_strided.", (False,) * 6),
408                #
409                # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
410                ("aten::fill_.Scalar", (True, False)),
411                ("autograd::engine::evaluate_function: MulBackward0", ()),
412                ("MulBackward0", (None,)),
413                ("aten::mul.Tensor", (False, False)),
414                (
415                    "autograd::engine::evaluate_function: torch::autograd::AccumulateGrad",
416                    (),
417                ),
418                ("torch::autograd::AccumulateGrad", (None,)),
419                ("aten::detach.", (False,)),
420                ("detach", (None,)),
421            ),
422        )
423
424    def test_match_schemas_tensorlist(self) -> None:
425        x = torch.ones((1,))
426        y = torch.ones((1,))
427        with profile() as prof:
428            torch.cat([x, y], axis=0)
429
430        self.assertEqual(
431            self.formatSchemas(prof),
432            (("aten::cat.", (False, False)),),
433        )
434
435    def test_data_flow_graph_with_annotations(self) -> None:
436        def f(x, y):
437            # torch._C._jit_get_schemas_for_operator will reject any name that
438            # is missing a namespace (denoted by the presence of "::"). We want
439            # to check that we skip both annotations which have no schema
440            # (SchemaMatcher.lookup_schemas returns an empty tuple) and
441            # annotations which cannot have a schema (SchemaMatcher.lookup_schemas
442            # returns None).
443            with torch.profiler.record_function("Namespaced::Annotation"):
444                with torch.profiler.record_function("My Annotation"):
445                    x.zero_()
446                    y.zero_()
447                    return {"x0": torch.ones_like(x), "y0": torch.zeros_like(y)}
448
449        inputs = {"x": torch.ones((1,)), "y": torch.ones((1,))}
450        self.assertExpectedInline(
451            self._run_and_format_data_flow(inputs, f),
452            """\
453            x:       T0
454            y:       T1
455            x0:      T2
456            y0:      T3
457
458            aten::zero_               T0(v0)           ->  T0(v1)
459            aten::zero_               T1(v0)           ->  T1(v1)
460            aten::ones_like           T0(v1)           ->  T2(v0)
461            aten::zeros_like          T1(v1)           ->  T3(v0)""",
462        )
463
464    def test_data_flow_graph_non_op_allocations(self) -> None:
465        def f(x):
466            x.mul(2)
467
468        # The python arg parser will convert the python scalar `2` to a Tensor
469        # to pass to `aten::mul`. As a result there is no op that "owns" the
470        # allocation. The Tensor deletions also do not happen in an op; they
471        # are collected as a result of the Python objects going out of scope.
472        self.assertExpectedInline(
473            self._run_and_format_data_flow({"x": torch.ones((1,))}, f),
474            """\
475            x:       T1
476
477            [memory]                                   ->  T0(v0)
478            aten::mul                 T0(v0), T1(v0)   ->
479            [memory]                  T0(v0*)          ->""",
480        )
481
482    def test_data_flow_graph_simple(self) -> None:
483        inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)}
484
485        def f0(x, y):
486            z = x.mul(y)
487            return {"z": z.view_as(z)}
488
489        def f1(x, y):
490            with torch.no_grad():
491                return f0(x, y)
492
493        self.assertExpectedInline(
494            self._run_and_format_data_flow(inputs, f0),
495            """\
496            x:       T0
497            y:       T1
498            z:       T2
499
500            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
501            aten::view_as             T2(v0)           ->""",
502        )
503
504        # Out of place is identical regardless of Autograd.
505        self.assertExpectedInline(
506            self._run_and_format_data_flow(inputs, f1),
507            """\
508            x:       T0
509            y:       T1
510            z:       T2
511
512            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
513            aten::view_as             T2(v0)           ->""",
514        )
515
516    def test_data_flow_graph_simple_inplace(self) -> None:
517        inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)}
518
519        def f0(x, y):
520            x.mul_(y)
521
522        def f1(x, y):
523            with torch.no_grad():
524                return f0(x, y)
525
526        # When Autograd is enabled a second Tensor `T2` is created to store
527        # the values of T0(v0) which are needed for backwards.
528        self.assertExpectedInline(
529            self._run_and_format_data_flow(inputs, f0),
530            """\
531            x:       T0
532            y:       T1
533
534            aten::mul_                T0(v0), T1(v0)   ->  T0(v1), T2(v0)""",
535        )
536
537        self.assertExpectedInline(
538            self._run_and_format_data_flow(inputs, f1),
539            """\
540            x:       T0
541            y:       T1
542
543            aten::mul_                T0(v0), T1(v0)   ->  T0(v1)""",
544        )
545
546    def test_data_flow_graph_simple_backward(self) -> None:
547        inputs = {
548            "x": torch.ones((1,)),
549            "w": torch.ones((1,), requires_grad=True),
550        }
551        self.assertExpectedInline(
552            self._run_and_format_data_flow(
553                inputs, lambda x, w: (x * w).sin().backward()
554            ),
555            """\
556            x:       T0
557            w:       T1
558            w.grad:   T7
559
560            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
561            aten::sin                 T2(v0)           ->  T3(v0)
562            aten::ones_like           T3(v0)           ->  T4(v0)
563            SinBackward0              T2(v0), T4(v0)   ->  T6(v0)
564            [memory]                  T2(v0*)          ->
565            MulBackward0              T0(v0), T6(v0)   ->  T7(v0)
566            [memory]                  T6(v0*)          ->
567            AccumulateGrad            T7(v0)           ->
568            [memory]                  T4(v0*)          ->
569            [memory]                  T3(v0*)          ->""",
570        )
571
572    def test_data_flow_graph_complicated(self) -> None:
573        def f():
574            x = torch.ones((25,))
575            y = x.mul(2).add_(2)
576            z = torch.sin(y, out=torch.empty_like(y))
577            return {"x": x, "y": y, "z": z}
578
579        # T1 is the `2` in `.mul(2)`. The Python arg parser automatically
580        # converts Scalar arguments to Tensors. The same is true for `T4`
581        # and `.add_(2)`.
582        self.assertExpectedInline(
583            self._run_and_format_data_flow({}, f),
584            """\
585            x:       T0
586            y:       T3
587            z:       T6
588
589            aten::ones                                 ->  T0(v0)
590            [memory]                                   ->  T1(v0)
591            aten::mul                 T0(v0), T1(v0)   ->  T3(v0)
592            [memory]                  T1(v0*)          ->
593            [memory]                                   ->  T4(v0)
594            aten::add_                T3(v0), T4(v0)   ->  T3(v1)
595            [memory]                  T4(v0*)          ->
596            aten::empty_like          T3(v1)           ->  T6(v0)
597            aten::sin                 T3(v1), T6(v0)   ->  T6(v1)""",
598        )
599
600        with profile() as prof:
601            f()
602
603        # `aten::mul` creates a temporary Tensor (T2), which is why the output
604        # has ID three rather than two.
605        mul_node = prof._memory_profile()._data_flow_graph.flow_nodes[2]
606        self.assertEqual(mul_node._event.name, "aten::mul")
607        self.assertEqual(len(mul_node.intermediates), 1)
608        self.assertEqual(mul_node.intermediates[0].id, 2)
609
610    def test_data_flow_graph_stacked(self) -> None:
611        inputs = {
612            "x": torch.ones((25,)),
613            "w0": torch.ones((1,), requires_grad=True),
614            "w1": torch.ones((1,), requires_grad=True),
615        }
616
617        def f(x, w0, w1):
618            return x.mul(w0).relu().mul(w1).relu().sum()
619
620        def f_fwd(**kwargs):
621            with torch.no_grad():
622                return {"loss": f(**kwargs)}
623
624        def f_fwd_bwd(**kwargs):
625            loss = f(**kwargs)
626            loss.backward()
627            return {"loss": loss}
628
629        self.assertExpectedInline(
630            self._run_and_format_data_flow(inputs, f_fwd),
631            """\
632            x:       T0
633            w0:      T1
634            w1:      T4
635            loss:    T7
636
637            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
638            aten::relu                T2(v0)           ->  T3(v0)
639            [memory]                  T2(v0*)          ->
640            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
641            [memory]                  T3(v0*)          ->
642            aten::relu                T5(v0)           ->  T6(v0)
643            [memory]                  T5(v0*)          ->
644            aten::sum                 T6(v0)           ->  T7(v0)
645            [memory]                  T6(v0*)          ->""",
646        )
647
648        self.assertExpectedInline(
649            self._run_and_format_data_flow(inputs, f_fwd_bwd),
650            """\
651            x:       T0
652            w0:      T1
653            w0.grad:  T15
654            w1:      T4
655            w1.grad:  T12
656            loss:    T7
657
658            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
659            aten::relu                T2(v0)           ->  T3(v0)
660            [memory]                  T2(v0*)          ->
661            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
662            aten::relu                T5(v0)           ->  T6(v0)
663            [memory]                  T5(v0*)          ->
664            aten::sum                 T6(v0)           ->  T7(v0)
665            aten::ones_like           T7(v0)           ->  T8(v0)
666            SumBackward0              T8(v0)           ->
667            ReluBackward0             T6(v0), T8(v0)   ->  T9(v0)
668            [memory]                  T6(v0*)          ->
669            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T10(v0), T11(v0)
670            aten::sum                 T10(v0)          ->  T12(v0)
671            [memory]                  T10(v0*)         ->
672            [memory]                  T9(v0*)          ->
673            AccumulateGrad            T12(v0)          ->
674            ReluBackward0             T3(v0), T11(v0)  ->  T13(v0)
675            [memory]                  T11(v0*)         ->
676            [memory]                  T3(v0*)          ->
677            MulBackward0              T0(v0), T13(v0)  ->  T14(v0)
678            aten::sum                 T14(v0)          ->  T15(v0)
679            [memory]                  T14(v0*)         ->
680            [memory]                  T13(v0*)         ->
681            AccumulateGrad            T15(v0)          ->
682            [memory]                  T8(v0*)          ->""",
683        )
684
685        # Second time grads are already initialized.
686        self.assertExpectedInline(
687            self._run_and_format_data_flow(inputs, f_fwd_bwd),
688            """\
689            x:       T0
690            w0:      T1
691            w0.grad:  T17
692            w1:      T4
693            w1.grad:  T13
694            loss:    T7
695
696            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
697            aten::relu                T2(v0)           ->  T3(v0)
698            [memory]                  T2(v0*)          ->
699            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
700            aten::relu                T5(v0)           ->  T6(v0)
701            [memory]                  T5(v0*)          ->
702            aten::sum                 T6(v0)           ->  T7(v0)
703            aten::ones_like           T7(v0)           ->  T8(v0)
704            SumBackward0              T8(v0)           ->
705            ReluBackward0             T6(v0), T8(v0)   ->  T9(v0)
706            [memory]                  T6(v0*)          ->
707            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T10(v0), T11(v0)
708            aten::sum                 T10(v0)          ->  T12(v0)
709            [memory]                  T10(v0*)         ->
710            [memory]                  T9(v0*)          ->
711            AccumulateGrad            T12(v0*), T13(v0)  ->  T13(v1)
712            ReluBackward0             T3(v0), T11(v0)  ->  T14(v0)
713            [memory]                  T11(v0*)         ->
714            [memory]                  T3(v0*)          ->
715            MulBackward0              T0(v0), T14(v0)  ->  T15(v0)
716            aten::sum                 T15(v0)          ->  T16(v0)
717            [memory]                  T15(v0*)         ->
718            [memory]                  T14(v0*)         ->
719            AccumulateGrad            T16(v0*), T17(v0)  ->  T17(v1)
720            [memory]                  T8(v0*)          ->""",
721        )
722
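        # NOTE: Everything below this `return` is intentionally unreachable; it
        #       references `self._format_graph`, which this class does not define.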
723        return
724
725        x = torch.ones((25,))
726        w0 = torch.ones((1,), requires_grad=True)
727        w1 = torch.ones((1,), requires_grad=True)
728
729        with profile() as prof_no_grad:
730            with torch.no_grad():
731                x.mul(w0).relu().mul(w1).relu().sum()
732
733        # TODO: one with `.logsumexp(dim=0)`
734
735        self.assertExpectedInline(
736            self._format_graph(prof_no_grad),
737            """\
738            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
739            aten::relu                T2(v0)           ->  T3(v0)
740            [memory]                  T2(v0*)          ->
741            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
742            [memory]                  T3(v0*)          ->
743            aten::relu                T5(v0)           ->  T6(v0)
744            [memory]                  T5(v0*)          ->
745            aten::sum                 T6(v0)           ->  T7(v0)
746            [memory]                  T6(v0*)          ->
747            [memory]                  T7(v0*)          ->""",
748        )
749
750        with profile() as prof_grad:
751            loss = x.mul(w0).relu().mul(w1).relu().sum()
752            loss.backward()
753
754        self.assertExpectedInline(
755            self._format_graph(prof_grad),
756            """\
757            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
758            aten::relu                T2(v0)           ->  T3(v0)
759            [memory]                  T2(v0*)          ->
760            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
761            aten::relu                T5(v0)           ->  T6(v0)
762            [memory]                  T5(v0*)          ->
763            aten::sum                 T6(v0)           ->  T7(v0)
764            aten::ones_like           T7(v0)           ->  T8(v0)
765            SumBackward0              T8(v0)           ->  T8(v1)
766            ReluBackward0             T6(v0), T8(v1)   ->  T8(v2), T9(v0)
767            [memory]                  T6(v0*)          ->
768            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T9(v1), T10(v0), T11(v0)
769            aten::sum                 T10(v0)          ->  T12(v0)
770            [memory]                  T10(v0*)         ->
771            [memory]                  T9(v1*)          ->
772            AccumulateGrad            T12(v0)          ->  T12(v1)
773            ReluBackward0             T3(v0), T11(v0)  ->  T11(v1), T13(v0)
774            [memory]                  T11(v1*)         ->
775            [memory]                  T3(v0*)          ->
776            MulBackward0              T0(v0), T13(v0)  ->  T13(v1), T14(v0)
777            aten::sum                 T14(v0)          ->  T15(v0)
778            [memory]                  T14(v0*)         ->
779            [memory]                  T13(v1*)         ->
780            AccumulateGrad            T15(v0)          ->  T15(v1)
781            [memory]                  T8(v2*)          ->""",
782        )
783
784        # Second time grads are already initialized.
785        with profile() as prof_grad:
786            loss = x.mul(w0).relu().mul(w1).relu().sum()
787            loss.backward()
788
789        self.assertExpectedInline(
790            self._format_graph(prof_grad),
791            """\
792            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
793            aten::relu                T2(v0)           ->  T3(v0)
794            [memory]                  T2(v0*)          ->
795            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
796            aten::relu                T5(v0)           ->  T6(v0)
797            [memory]                  T5(v0*)          ->
798            aten::sum                 T6(v0)           ->  T7(v0)
799            aten::ones_like           T7(v0)           ->  T8(v0)
800            SumBackward0              T8(v0)           ->  T8(v1)
801            ReluBackward0             T6(v0), T8(v1)   ->  T8(v2), T9(v0)
802            [memory]                  T6(v0*)          ->
803            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T9(v1), T10(v0), T11(v0)
804            aten::sum                 T10(v0)          ->  T12(v0)
805            [memory]                  T10(v0*)         ->
806            [memory]                  T9(v1*)          ->
807            AccumulateGrad            T12(v0*), T13(v0)  ->  T13(v1)
808            ReluBackward0             T3(v0), T11(v0)  ->  T11(v1), T14(v0)
809            [memory]                  T11(v1*)         ->
810            [memory]                  T3(v0*)          ->
811            MulBackward0              T0(v0), T14(v0)  ->  T14(v1), T15(v0)
812            aten::sum                 T15(v0)          ->  T16(v0)
813            [memory]                  T15(v0*)         ->
814            [memory]                  T14(v1*)         ->
815            AccumulateGrad            T16(v0*), T17(v0)  ->  T17(v1)
816            [memory]                  T8(v2*)          ->""",
817        )
818
819
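# Hedged sketch (illustrative only) of the data-flow-graph access pattern that
# `TestDataFlow._run_and_format_data_flow` formats above: each flow node wraps a
# profiler event plus versioned tensor inputs/outputs, and `[memory]` nodes mark
# allocations and deletions that no op owns. Only op/event names are collected
# here.
def _example_flow_node_names() -> List[str]:
    with profile() as prof:
        torch.ones((4,)).mul(2)
    graph = prof._memory_profile()._data_flow_graph
    return [node._event.name for node in graph.flow_nodes]

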
820@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
821class TestMemoryProfilerE2E(TestCase):
822    @staticmethod
823    def _lookup_tensor_categories(
824        t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile
825    ) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
826        storage = t.storage()
827        if storage is None:
828            raise ValueError("Cannot look up uninitialized Tensor.")
829
830        snapshot = memory_profile._category_snapshot()
831        ids = {
832            key.storage.allocation_id
833            for key, _ in snapshot
834            if key.storage.ptr == storage.data_ptr() and key.device == storage.device
835        }
836
837        return {
838            (key, version): category
839            for (key, version), category in memory_profile._category_snapshot().items()
840            #
841            # If a Tensor is live we want the most recent ID
842            if key.storage.allocation_id == max(ids | {-1})
843        }
844
845    def _run_and_check_parameters_and_gradients(
846        self, inner_fn, model, grads_none: bool = False
847    ):
848        with profile() as prof:
849            inner_fn()
850
851        memory_profile = prof._memory_profile()
852
853        def assert_category(
854            t: torch.Tensor,
855            category: _memory_profiler.Category,
856            should_be_none: bool = False,
857        ):
858            if should_be_none:
859                self.assertIsNone(t, "tensor should be None but is not.")
860                return
861            self.assertIsNotNone(t)
862            categories = self._lookup_tensor_categories(t, memory_profile)
863            self.assertGreater(len(categories), 0)
864            self.assertTrue(all(c == category for c in categories.values()), categories)
865
866        for p in model.parameters():
867            assert_category(p, _memory_profiler.Category.PARAMETER)
868            assert_category(p.grad, _memory_profiler.Category.GRADIENT, grads_none)
869
870        # Rely on internal asserts
871        _ = memory_profile.timeline
872
873    def _run_and_format_categories(self, fn, indent=12):
874        """Generate summary of assigned categories for expecttest."""
875
876        # Use `__torch_dispatch__` to collect ground truth.
877        with RecordInputOutputDispatchMode() as record_ops, profile() as prof:
878            fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-")))
879
880        memory_profile = prof._memory_profile()
881        ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {}
882        snapshot = memory_profile._category_snapshot()
883
884        # Build map from observed live Tensors to the memory profiler's
885        # TensorKey representation.
886        for op in memory_profile._op_tree.dfs():
887            if op.typed[0] == _EventType.TorchOp:
888                inputs = pytree.tree_leaves(op.typed[1].inputs)
889                for t in (i for i in inputs if isinstance(i, _TensorMetadata)):
890                    key = _memory_profiler.TensorKey.from_tensor(t)
891                    if key:
892                        ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key
893
894        def format_categories(ptr_pair: Tuple[int, int]) -> str:
895            target_key = ptr_pair_to_key.get(ptr_pair, None)
896            if target_key is None:
897                return "???"
898
899            matches = tuple(
900                (version, category.name if category else "???")
901                for (key, version), category in snapshot.items()
902                if key == target_key
903            )
904            assert matches, "Failed to lookup Tensor"
905
906            # Deduplicate version bumps which don't change the category.
907            categories = [matches[0][1]]
908            for _, category in matches:
909                if category != categories[-1]:
910                    categories.append(category)
911
912            return f"{target_key.storage.allocation_id} ({','.join(categories)})"
913
914        out: List[str] = []
915        for name, inputs, outputs in record_ops.results:
916            if inputs or outputs:
917                # PyTorch ops
918                inputs_str = ", ".join(format_categories(i) for i in inputs)
919                outputs_str = ", ".join(format_categories(i) for i in outputs)
920                out.append(f"{name:<40} {inputs_str:<45} -> {outputs_str}")
921
922            else:
923                # Marked regions.
924                out.append(f"\n{name}")
925
926        return textwrap.indent("\n".join(out), " " * indent)
927
928    def test_parameters_and_gradients(self):
929        model = torch.nn.Sequential(
930            torch.nn.Linear(2, 2), ScaleLayer(), torch.nn.Linear(2, 1), ScaleLayer()
931        )
932        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
933
934        def fwd_only():
935            _ = model(torch.ones((2, 2)))
936
937        def fwd_bwd_step():
938            optimizer.zero_grad()
939            y = model(torch.ones((2, 2)))
940            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
941            optimizer.step()
942
943        # If we profile the first step then gradients will not have been
944        # created when we call `model.forward`, so if we don't call `.backward`
945        # then gradients are never created.
946        self._run_and_check_parameters_and_gradients(
947            inner_fn=fwd_only, model=model, grads_none=True
948        )
949
950        # On the first step we must rely on `AccumulateGrad`, since gradients
951        # did not exist when `model.forward` was called.
952        self.assertTrue(all(p.grad is None for p in model.parameters()))
953        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)
954
955        # After one step the python tracer will also flag gradients.
956        self.assertTrue(not any(p.grad is None for p in model.parameters()))
957        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)
958
959        # The parameter gradients are not used but we still detect them with
960        # the python tracer.
961        self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model)
962
963    def test_parameters_and_gradients_set_to_none(self):
964        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
965        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
966
967        def fwd_bwd_step():
968            for _ in range(3):
969                # zero grads at the start so gradients are still live to be
970                # checked.
971                optimizer.zero_grad(set_to_none=True)
972
973                y = model(torch.ones((2, 2)))
974                torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
975                optimizer.step()
976
977        fwd_bwd_step()
978        self.assertTrue(not any(p.grad is None for p in model.parameters()))
979        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)
980
981        optimizer.zero_grad(set_to_none=True)
982        self.assertTrue(all(p.grad is None for p in model.parameters()))
983        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)
984
985    def test_inputs_fwd(self):
986        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
987        inputs = [torch.ones((2, 2)) for _ in range(2)]
988
989        with profile() as prof:
990            # Inputs which were allocated before profiling began
991            for x in inputs:
992                _ = model(x)
993
994            # Inputs which were allocated after profiling began
995            for _ in range(2):
996                x = torch.ones((2, 2))
997                inputs.append(x)
998                _ = model(x)
999
1000        memory_profile = prof._memory_profile()
1001        for x in inputs:
1002            categories = self._lookup_tensor_categories(x, memory_profile)
1003            self.assertGreater(len(categories), 0)
1004            self.assertTrue(
1005                all(i == _memory_profiler.Category.INPUT for i in categories.values()),
1006                categories,
1007            )
1008
1009        snapshot = memory_profile._category_snapshot()
1010        self.assertTrue(_memory_profiler.Category.INPUT in snapshot.values())
1011
1012    def test_inputs_fwd_lazy(self):
1013        model = torch.nn.Sequential(LazyLinear(2, 2), LazyLinear(2, 1))
1014        inputs = [torch.ones((2, 2)) for _ in range(2)]
1015
1016        with profile() as prof:
1017            # Inputs which were allocated before profiling began
1018            for x in inputs:
1019                _ = model(x)
1020
1021            # Inputs which were allocated after profiling began
1022            for _ in range(2):
1023                x = torch.ones((2, 2))
1024                inputs.append(x)
1025                _ = model(x)
1026
1027        # For now we can't make any meaningful statements without a backward
1028        # pass. Here we simply ensure that forward-only passes don't generate
1029        # false positive category classifications.
1030        memory_profile = prof._memory_profile()
1031        for x in inputs:
1032            categories = self._lookup_tensor_categories(x, memory_profile)
1033            self.assertGreater(len(categories), 0)
1034            self.assertTrue(all(i is None for i in categories.values()), categories)
1035
1036        snapshot = memory_profile._category_snapshot()
1037        self.assertFalse(_memory_profiler.Category.INPUT in snapshot.values())
1038
1039    def test_inputs_fwd_bwd(self):
1040        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
1041        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
1042        inputs_targets = [(torch.ones((2, 2)), torch.rand((2, 1))) for _ in range(2)]
1043
1044        def fwd_bwd_step(x, targets):
1045            y = model(x)
1046            torch.nn.functional.mse_loss(y, targets).backward()
1047            optimizer.step()
1048            optimizer.zero_grad()
1049
1050        with profile() as prof:
1051            # Inputs which were allocated before profiling began
1052            for x, targets in inputs_targets:
1053                fwd_bwd_step(x, targets)
1054
1055            # Inputs which were allocated after profiling began
1056            for _ in range(2):
1057                x = torch.ones((2, 2))
1058                targets = torch.rand((2, 1))
1059                inputs_targets.append((x, targets))
1060                fwd_bwd_step(x, targets)
1061
1062        memory_profile = prof._memory_profile()
1063
1064        def check(t):
1065            categories = self._lookup_tensor_categories(t, memory_profile)
1066            self.assertGreater(len(categories), 0)
1067            self.assertTrue(
1068                all(i == _memory_profiler.Category.INPUT for i in categories.values())
1069            )
1070
1071        for x, targets in inputs_targets:
1072            check(x)
1073            check(targets)
1074
1075    def test_lazily_initialized(self) -> None:
1076        model = torch.nn.Sequential(
1077            torch.nn.Linear(2, 2),
1078            torch.nn.ReLU(),
1079            LazyLinear(2, 2),
1080            torch.nn.ReLU(),
1081            torch.nn.Linear(2, 1),
1082        )
1083
1084        self.assertEqual(len(list(model.parameters())), 4)
1085
1086        def inner_fn():
1087            y = model(torch.ones((2, 2)))
1088            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
1089            optimizer.zero_grad()
1090            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
1091            optimizer.step()
1092
1093        self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)
1094        self.assertEqual(len(list(model.parameters())), 6)
1095
1096    def test_manual_optimizer_step(self) -> None:
1097        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
1098
1099        def inner_fn():
1100            y = model(torch.ones((2, 2)))
1101            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
1102
1103            with torch.no_grad():
1104                for p in model.parameters():
1105                    grad = p.grad
1106                    self.assertIsNotNone(grad)
1107                    p.add_(grad, alpha=-0.1)
1108
1109        self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)
1110
1111    def test_categories_e2e_simple_fwd(self) -> None:
1112        w0 = torch.ones((1,), requires_grad=True)
1113        w1 = torch.ones((1,), requires_grad=True)
1114
1115        def step_fn(_):
1116            x = torch.ones((2, 2))
1117            y = torch.cat([x * w0, x * w1], dim=1)
1118
1119        # NOTE: We expect all categories to be unknown (`???`). This is simply a
1120        #       sanity check to ensure that we do not over-label.
1121        self.assertExpectedInline(
1122            self._run_and_format_categories(step_fn),
1123            """\
1124            aten::ones                                                                             -> 1 (???)
1125            aten::mul.Tensor                         1 (???), 2 (???)                              -> 3 (???)
1126            aten::mul.Tensor                         1 (???), 4 (???)                              -> 5 (???)
1127            aten::cat                                3 (???), 5 (???)                              -> ???""",
1128        )
1129
1130    def test_categories_e2e_simple_fwd_bwd(self) -> None:
1131        w0 = torch.ones((1,), requires_grad=True)
1132        w1 = torch.ones((1,), requires_grad=True)
1133
1134        def step_fn(mark_region):
1135            x = torch.ones((2, 2))
1136            targets = torch.ones((2, 4))
1137
1138            mark_region("Forward & loss")
1139            y = torch.cat([x * w0, x * w1], dim=1)
1140            loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets)
1141
1142            mark_region("Backward")
1143            loss.backward()
1144
1145        self.assertExpectedInline(
1146            self._run_and_format_categories(step_fn),
1147            """\
1148            aten::ones                                                                             -> 1 (INPUT)
1149            aten::ones                                                                             -> 2 (INPUT)
1150
1151            -- Forward & loss ---------------------------------------------------------------------------------------
1152            aten::mul.Tensor                         1 (INPUT), 3 (INPUT)                          -> 4 (INPUT)
1153            aten::mul.Tensor                         1 (INPUT), 5 (INPUT)                          -> 6 (INPUT)
1154            aten::cat                                4 (INPUT), 6 (INPUT)                          -> 7 (INPUT)
1155            aten::binary_cross_entropy_with_logits   7 (INPUT), 2 (INPUT)                          -> 11 (INPUT)
1156
1157            -- Backward ---------------------------------------------------------------------------------------------
1158            aten::ones_like                          11 (INPUT)                                    -> 14 (INPUT)
1159            aten::sigmoid                            7 (INPUT)                                     -> 15 (TEMPORARY)
1160            aten::sub.Tensor                         15 (TEMPORARY), 2 (INPUT)                     -> 16 (TEMPORARY)
1161            aten::mul.Tensor                         16 (TEMPORARY), 14 (INPUT)                    -> 17 (AUTOGRAD_DETAIL)
1162            aten::div_.Scalar                        17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1163            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1164            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1165            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 20 (AUTOGRAD_DETAIL)
1166            aten::sum.dim_IntList                    20 (AUTOGRAD_DETAIL)                          -> 21 (GRADIENT)
1167            aten::view                               21 (GRADIENT)                                 -> 21 (GRADIENT)
1168            aten::detach                             21 (GRADIENT)                                 -> 21 (GRADIENT)
1169            aten::detach                             21 (GRADIENT)                                 -> ???
1170            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 22 (AUTOGRAD_DETAIL)
1171            aten::sum.dim_IntList                    22 (AUTOGRAD_DETAIL)                          -> 23 (GRADIENT)
1172            aten::view                               23 (GRADIENT)                                 -> 23 (GRADIENT)
1173            aten::detach                             23 (GRADIENT)                                 -> 23 (GRADIENT)
1174            aten::detach                             23 (GRADIENT)                                 -> ???""",
1175        )
1176
1177    def test_categories_e2e_simple_fwd_bwd_step(self) -> None:
1178        w0 = torch.ones((1,), requires_grad=True)
1179        w1 = torch.ones((1,), requires_grad=True)
1180        optimizer = torch.optim.SGD([w0, w1], lr=0.1)
1181
1182        def step_fn(mark_region):
1183            x = torch.ones((2, 2))
1184            targets = torch.ones((2, 4))
1185
1186            mark_region("Forward & loss")
1187            y = torch.cat([x * w0, x * w1], dim=1)
1188            loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets)
1189
1190            mark_region("Backward")
1191            loss.backward()
1192
1193            mark_region("Optimizer")
1194            optimizer.step()
1195            optimizer.zero_grad()
1196
1197        self.assertExpectedInline(
1198            self._run_and_format_categories(step_fn),
1199            """\
1200            aten::ones                                                                             -> 1 (INPUT)
1201            aten::ones                                                                             -> 2 (INPUT)
1202
1203            -- Forward & loss ---------------------------------------------------------------------------------------
1204            aten::mul.Tensor                         1 (INPUT), 3 (PARAMETER)                      -> 4 (ACTIVATION)
1205            aten::mul.Tensor                         1 (INPUT), 5 (PARAMETER)                      -> 6 (ACTIVATION)
1206            aten::cat                                4 (ACTIVATION), 6 (ACTIVATION)                -> 7 (ACTIVATION)
1207            aten::binary_cross_entropy_with_logits   7 (ACTIVATION), 2 (INPUT)                     -> 11 (ACTIVATION)
1208
1209            -- Backward ---------------------------------------------------------------------------------------------
1210            aten::ones_like                          11 (ACTIVATION)                               -> 14 (ACTIVATION)
1211            aten::sigmoid                            7 (ACTIVATION)                                -> 15 (TEMPORARY)
1212            aten::sub.Tensor                         15 (TEMPORARY), 2 (INPUT)                     -> 16 (TEMPORARY)
1213            aten::mul.Tensor                         16 (TEMPORARY), 14 (ACTIVATION)               -> 17 (AUTOGRAD_DETAIL)
1214            aten::div_.Scalar                        17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1215            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1216            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1217            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 20 (AUTOGRAD_DETAIL)
1218            aten::sum.dim_IntList                    20 (AUTOGRAD_DETAIL)                          -> 21 (GRADIENT)
1219            aten::view                               21 (GRADIENT)                                 -> 21 (GRADIENT)
1220            aten::detach                             21 (GRADIENT)                                 -> 21 (GRADIENT)
1221            aten::detach                             21 (GRADIENT)                                 -> 21 (GRADIENT)
1222            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 22 (AUTOGRAD_DETAIL)
1223            aten::sum.dim_IntList                    22 (AUTOGRAD_DETAIL)                          -> 23 (GRADIENT)
1224            aten::view                               23 (GRADIENT)                                 -> 23 (GRADIENT)
1225            aten::detach                             23 (GRADIENT)                                 -> 23 (GRADIENT)
1226            aten::detach                             23 (GRADIENT)                                 -> 23 (GRADIENT)
1227
1228            -- Optimizer --------------------------------------------------------------------------------------------
1229            aten::add_.Tensor                        3 (PARAMETER), 23 (GRADIENT)                  -> 3 (PARAMETER)
1230            aten::add_.Tensor                        5 (PARAMETER), 21 (GRADIENT)                  -> 5 (PARAMETER)""",
1231        )
1232
1233    def test_categories_e2e_simple_module_fwd(self) -> None:
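            # Forward only through a single Linear layer: the weight and bias
            # are categorized as PARAMETER and the addmm output as ACTIVATION.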
1234        model = torch.nn.Linear(2, 4, bias=True)
1235        self.assertExpectedInline(
1236            self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))),
1237            """\
1238            aten::ones                                                                             -> 1 (INPUT)
1239            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1240            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)""",
1241        )
1242
1243    def test_categories_e2e_simple_module_fwd_bwd(self) -> None:
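            # Forward plus backward with no optimizer: the backward region
            # should yield GRADIENT outputs for both the weight and the bias.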
1244        model = torch.nn.Linear(2, 1, bias=True)
1245
1246        def step_fn(mark_region):
1247            mark_region("Forward & loss")
1248            loss = model(torch.ones((2, 2))).sum()
1249
1250            mark_region("Backward")
1251            loss.backward()
1252
1253        self.assertExpectedInline(
1254            self._run_and_format_categories(step_fn),
1255            """\
1256
1257            -- Forward & loss ---------------------------------------------------------------------------------------
1258            aten::ones                                                                             -> 1 (INPUT)
1259            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1260            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
1261            aten::sum                                4 (ACTIVATION)                                -> 5 (ACTIVATION)
1262
1263            -- Backward ---------------------------------------------------------------------------------------------
1264            aten::ones_like                          5 (ACTIVATION)                                -> 6 (ACTIVATION)
1265            aten::expand                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1266            aten::t                                  6 (ACTIVATION)                                -> 6 (ACTIVATION)
1267            aten::mm                                 6 (ACTIVATION), 1 (INPUT)                     -> 7 (GRADIENT)
1268            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1269            aten::sum.dim_IntList                    6 (ACTIVATION)                                -> 9 (GRADIENT)
1270            aten::view                               9 (GRADIENT)                                  -> 9 (GRADIENT)
1271            aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
1272            aten::detach                             9 (GRADIENT)                                  -> ???
1273            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1274            aten::detach                             7 (GRADIENT)                                  -> 7 (GRADIENT)
1275            aten::detach                             7 (GRADIENT)                                  -> ???""",
1276        )
1277
1278    def test_categories_e2e_simple_module_fwd_bwd_step(self) -> None:
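            # Same model as above, but stepping SGD with momentum, so the
            # "Optimizer" region also clones the gradients into
            # OPTIMIZER_STATE (momentum buffers).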
1279        model = torch.nn.Linear(2, 1, bias=True)
1280        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
1281
1282        def step_fn(mark_region):
1283            mark_region("Forward & loss")
1284            loss = model(torch.ones((2, 2))).sum()
1285
1286            mark_region("Backward")
1287            loss.backward()
1288
1289            mark_region("Optimizer")
1290            optimizer.step()
1291            optimizer.zero_grad()
1292
1293        self.assertExpectedInline(
1294            self._run_and_format_categories(step_fn),
1295            """\
1296
1297            -- Forward & loss ---------------------------------------------------------------------------------------
1298            aten::ones                                                                             -> 1 (INPUT)
1299            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1300            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
1301            aten::sum                                4 (ACTIVATION)                                -> 5 (ACTIVATION)
1302
1303            -- Backward ---------------------------------------------------------------------------------------------
1304            aten::ones_like                          5 (ACTIVATION)                                -> 6 (ACTIVATION)
1305            aten::expand                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1306            aten::t                                  6 (ACTIVATION)                                -> 6 (ACTIVATION)
1307            aten::mm                                 6 (ACTIVATION), 1 (INPUT)                     -> 7 (GRADIENT)
1308            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1309            aten::sum.dim_IntList                    6 (ACTIVATION)                                -> 9 (GRADIENT)
1310            aten::view                               9 (GRADIENT)                                  -> 9 (GRADIENT)
1311            aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
1312            aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
1313            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1314            aten::detach                             7 (GRADIENT)                                  -> 7 (GRADIENT)
1315            aten::detach                             7 (GRADIENT)                                  -> 7 (GRADIENT)
1316
1317            -- Optimizer --------------------------------------------------------------------------------------------
1318            aten::clone                              7 (GRADIENT)                                  -> 10 (OPTIMIZER_STATE)
1319            aten::detach                             10 (OPTIMIZER_STATE)                          -> 10 (OPTIMIZER_STATE)
1320            aten::detach                             10 (OPTIMIZER_STATE)                          -> 10 (OPTIMIZER_STATE)
1321            aten::add_.Tensor                        2 (PARAMETER), 10 (OPTIMIZER_STATE)           -> 2 (PARAMETER)
1322            aten::clone                              9 (GRADIENT)                                  -> 11 (OPTIMIZER_STATE)
1323            aten::detach                             11 (OPTIMIZER_STATE)                          -> 11 (OPTIMIZER_STATE)
1324            aten::detach                             11 (OPTIMIZER_STATE)                          -> 11 (OPTIMIZER_STATE)
1325            aten::add_.Tensor                        3 (PARAMETER), 11 (OPTIMIZER_STATE)           -> 3 (PARAMETER)""",
1326        )
1327
1328    def test_categories_e2e_sequential_fwd(self) -> None:
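            # Forward only through a small Sequential stack; note that the
            # outputs of the detach ops are left uncategorized ("???").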
1329        model = torch.nn.Sequential(
1330            torch.nn.Linear(2, 4, bias=True),
1331            torch.nn.ReLU(),
1332            torch.nn.Linear(4, 4, bias=False),
1333            torch.nn.Softmax(dim=1),
1334        )
1335        self.assertExpectedInline(
1336            self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))),
1337            """\
1338            aten::ones                                                                             -> 1 (INPUT)
1339            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1340            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
1341            aten::relu                               4 (ACTIVATION)                                -> 5 (ACTIVATION)
1342            aten::detach                             5 (ACTIVATION)                                -> ???
1343            aten::t                                  6 (PARAMETER)                                 -> 6 (PARAMETER)
1344            aten::mm                                 5 (ACTIVATION), 6 (PARAMETER)                 -> 7 (ACTIVATION)
1345            aten::_softmax                           7 (ACTIVATION)                                -> 8 (ACTIVATION)
1346            aten::detach                             8 (ACTIVATION)                                -> ???""",
1347        )
1348
1349    def test_categories_e2e_sequential_fwd_bwd(self) -> None:
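            # Forward, loss, and backward over the Sequential model, each in
            # its own marked region; backward intermediates are expected to be
            # AUTOGRAD_DETAIL while the parameter grads become GRADIENT.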
1350        model = torch.nn.Sequential(
1351            torch.nn.Linear(2, 4, bias=True),
1352            torch.nn.ReLU(),
1353            torch.nn.Linear(4, 4, bias=False),
1354            torch.nn.Softmax(dim=1),
1355        )
1356
1357        def step_fn(mark_region):
1358            x = torch.ones((2, 2))
1359            targets = torch.ones((2, 4))
1360
1361            mark_region("Forward")
1362            y = model(x)
1363
1364            mark_region("Loss")
1365            loss = torch.sum((y - targets) ** 2).mean()
1366
1367            mark_region("Backward")
1368            loss.backward()
1369
1370        self.assertExpectedInline(
1371            self._run_and_format_categories(step_fn),
1372            """\
1373            aten::ones                                                                             -> 1 (INPUT)
1374            aten::ones                                                                             -> 2 (INPUT)
1375
1376            -- Forward ----------------------------------------------------------------------------------------------
1377            aten::t                                  3 (PARAMETER)                                 -> 3 (PARAMETER)
1378            aten::addmm                              4 (PARAMETER), 1 (INPUT), 3 (PARAMETER)       -> 5 (ACTIVATION)
1379            aten::relu                               5 (ACTIVATION)                                -> 6 (ACTIVATION)
1380            aten::detach                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1381            aten::t                                  7 (PARAMETER)                                 -> 7 (PARAMETER)
1382            aten::mm                                 6 (ACTIVATION), 7 (PARAMETER)                 -> 8 (ACTIVATION)
1383            aten::_softmax                           8 (ACTIVATION)                                -> 9 (ACTIVATION)
1384            aten::detach                             9 (ACTIVATION)                                -> 9 (ACTIVATION)
1385
1386            -- Loss -------------------------------------------------------------------------------------------------
1387            aten::sub.Tensor                         9 (ACTIVATION), 2 (INPUT)                     -> 10 (ACTIVATION)
1388            aten::pow.Tensor_Scalar                  10 (ACTIVATION)                               -> 11 (ACTIVATION)
1389            aten::sum                                11 (ACTIVATION)                               -> 12 (ACTIVATION)
1390            aten::mean                               12 (ACTIVATION)                               -> 13 (ACTIVATION)
1391
1392            -- Backward ---------------------------------------------------------------------------------------------
1393            aten::ones_like                          13 (ACTIVATION)                               -> 16 (ACTIVATION)
1394            aten::expand                             16 (ACTIVATION)                               -> 16 (ACTIVATION)
1395            aten::div.Scalar                         16 (ACTIVATION)                               -> 19 (AUTOGRAD_DETAIL)
1396            aten::expand                             19 (AUTOGRAD_DETAIL)                          -> 19 (AUTOGRAD_DETAIL)
1397            aten::pow.Tensor_Scalar                  10 (ACTIVATION)                               -> 20 (TEMPORARY)
1398            aten::mul.Scalar                         20 (TEMPORARY)                                -> 23 (TEMPORARY)
1399            aten::mul.Tensor                         19 (AUTOGRAD_DETAIL), 23 (TEMPORARY)          -> 24 (AUTOGRAD_DETAIL)
1400            aten::detach                             9 (ACTIVATION)                                -> 9 (ACTIVATION)
1401            aten::_softmax_backward_data             24 (AUTOGRAD_DETAIL), 9 (ACTIVATION)          -> 25 (AUTOGRAD_DETAIL)
1402            aten::t                                  25 (AUTOGRAD_DETAIL)                          -> 25 (AUTOGRAD_DETAIL)
1403            aten::mm                                 25 (AUTOGRAD_DETAIL), 6 (ACTIVATION)          -> 26 (GRADIENT)
1404            aten::t                                  26 (GRADIENT)                                 -> 26 (GRADIENT)
1405            aten::t                                  7 (PARAMETER)                                 -> 7 (PARAMETER)
1406            aten::mm                                 25 (AUTOGRAD_DETAIL), 7 (PARAMETER)           -> 27 (AUTOGRAD_DETAIL)
1407            aten::t                                  26 (GRADIENT)                                 -> 26 (GRADIENT)
1408            aten::detach                             26 (GRADIENT)                                 -> 26 (GRADIENT)
1409            aten::detach                             26 (GRADIENT)                                 -> ???
1410            aten::detach                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1411            aten::threshold_backward                 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION)          -> 28 (AUTOGRAD_DETAIL)
1412            aten::t                                  28 (AUTOGRAD_DETAIL)                          -> 28 (AUTOGRAD_DETAIL)
1413            aten::mm                                 28 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 29 (GRADIENT)
1414            aten::t                                  29 (GRADIENT)                                 -> 29 (GRADIENT)
1415            aten::sum.dim_IntList                    28 (AUTOGRAD_DETAIL)                          -> 30 (GRADIENT)
1416            aten::view                               30 (GRADIENT)                                 -> 30 (GRADIENT)
1417            aten::detach                             30 (GRADIENT)                                 -> 30 (GRADIENT)
1418            aten::detach                             30 (GRADIENT)                                 -> ???
1419            aten::t                                  29 (GRADIENT)                                 -> 29 (GRADIENT)
1420            aten::detach                             29 (GRADIENT)                                 -> 29 (GRADIENT)
1421            aten::detach                             29 (GRADIENT)                                 -> ???""",
1422        )
1423
1424    def test_memory_timeline(self) -> None:
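            # Runs one full training step under the profiler and checks the
            # raw memory timeline: events must be time-ordered, PREEXISTING
            # entries have t == -1, and allocations over 1 kB are compared
            # against the expected trace below.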
1425        model = torch.nn.Sequential(
1426            torch.nn.Linear(64, 512, bias=True),
1427            torch.nn.ReLU(),
1428            torch.nn.Linear(512, 512, bias=False),
1429            torch.nn.Softmax(dim=1),
1430        )
1431        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
1432
1433        with profile() as prof:
1434            x = torch.ones((1024, 64))
1435            targets = torch.ones((1024, 512))
1436            y = model(x)
1437            loss = torch.nn.functional.mse_loss(y, targets)
1438            loss.backward()
1439            optimizer.step()
1440            optimizer.zero_grad()
1441
1442        memory_profile = prof._memory_profile()
1443        timeline = memory_profile.timeline
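            # Each timeline entry is a (time, action, (key, version), size)
            # tuple, unpacked in the checks and formatting below.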
1444        times = tuple(t for t, _, _, _ in timeline)
1445        self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times)
1446        self.assertTrue(
1447            all(
1448                (t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0)
1449                for t, action, _, _ in timeline
1450            )
1451        )
1452
1453        def category_name(category):
1454            return category.name if category else "???"
1455
1456        def format_action(action, key, version):
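                # For INCREMENT_VERSION events, render "OLD -> NEW" when the
                # version bump changes the storage's category; otherwise just
                # render the current category (or "???" if it has none).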
1457            category = memory_profile._categories.get(key, version)
1458            if action == _memory_profiler.Action.INCREMENT_VERSION:
1459                new_category = memory_profile._categories.get(key, version + 1)
1460                if category != new_category:
1461                    return f"{category_name(category)} -> {category_name(new_category)}"
1462            return category_name(category)
1463
1464        def format_size(size: int):
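                # Sizes under 1 kB are shown with one decimal place; larger
                # sizes are rounded down to whole kB.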
1465            if size < 1024:
1466                return f"{size / 1024:3.1f} kB"
1467            return f"{size // 1024} kB"
1468
1469        # We generate sequential IDs for Tensors; however, platforms vary
1470        # slightly in the exact computation executed. If that creates extra
1471        # tensors, the IDs shift and the unit test fails even though the
1472        # behavior under test is unchanged. To correct for this we assign
1473        # sequential numbers to the tensors that are actually tested,
1474        # effectively suppressing the extraneous implementation details.
1475        id_map = {}
1476
1477        def id_for_testing(key):
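                # Assign dense indices in order of first appearance, e.g.
                # (hypothetical) allocation_ids 41, 7, 55 map to 0, 1, 2.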
1478            return id_map.setdefault(key.storage.allocation_id, len(id_map))
1479
1480        lines = [
1481            f"{action.name.lower():<25}  {format_action(action, key, version):<25}  "
1482            f"{id_for_testing(key):>3}(v{version}) {format_size(size):>15}"
1483            for _, action, (key, version), size in prof._memory_profile().timeline
1484            # We generally don't care about tiny allocations during memory
1485            # profiling, and they add a lot of noise to the unit test.
1486            if size > 1024
1487        ]
1488
1489        self.assertExpectedInline(
1490            textwrap.indent("\n".join(lines), " " * 12),
1491            """\
1492            preexisting                PARAMETER                    0(v0)          128 kB
1493            preexisting                PARAMETER                    1(v0)            2 kB
1494            preexisting                PARAMETER                    2(v0)         1024 kB
1495            create                     INPUT                        3(v0)          256 kB
1496            create                     INPUT                        4(v0)         2048 kB
1497            create                     ACTIVATION                   5(v0)         2048 kB
1498            create                     ACTIVATION                   6(v0)         2048 kB
1499            destroy                    ACTIVATION                   5(v0)         2048 kB
1500            create                     ACTIVATION                   7(v0)         2048 kB
1501            create                     ACTIVATION                   8(v0)         2048 kB
1502            destroy                    ACTIVATION                   7(v0)         2048 kB
1503            create                     ACTIVATION                   9(v0)         2048 kB
1504            create                     TEMPORARY                   10(v0)         2048 kB
1505            destroy                    TEMPORARY                   10(v0)         2048 kB
1506            create                     AUTOGRAD_DETAIL             11(v0)         2048 kB
1507            create                     AUTOGRAD_DETAIL             12(v0)         2048 kB
1508            destroy                    AUTOGRAD_DETAIL             11(v0)         2048 kB
1509            create                     GRADIENT                    13(v0)         1024 kB
1510            create                     AUTOGRAD_DETAIL             14(v0)         2048 kB
1511            destroy                    AUTOGRAD_DETAIL             12(v0)         2048 kB
1512            create                     AUTOGRAD_DETAIL             15(v0)         2048 kB
1513            destroy                    AUTOGRAD_DETAIL             14(v0)         2048 kB
1514            destroy                    ACTIVATION                   6(v0)         2048 kB
1515            create                     GRADIENT                    16(v0)          128 kB
1516            create                     GRADIENT                    17(v0)            2 kB
1517            destroy                    AUTOGRAD_DETAIL             15(v0)         2048 kB
1518            create                     OPTIMIZER_STATE             18(v0)          128 kB
1519            create                     OPTIMIZER_STATE             19(v0)          128 kB
1520            create                     OPTIMIZER_STATE             20(v0)            2 kB
1521            create                     OPTIMIZER_STATE             21(v0)            2 kB
1522            create                     OPTIMIZER_STATE             22(v0)         1024 kB
1523            create                     OPTIMIZER_STATE             23(v0)         1024 kB
1524            increment_version          OPTIMIZER_STATE             18(v0)          128 kB
1525            increment_version          OPTIMIZER_STATE             19(v0)          128 kB
1526            increment_version          OPTIMIZER_STATE             19(v1)          128 kB
1527            create                     ???                         24(v0)          128 kB
1528            create                     ???                         25(v0)          128 kB
1529            destroy                    ???                         24(v0)          128 kB
1530            increment_version          ???                         25(v0)          128 kB
1531            increment_version          PARAMETER                    0(v0)          128 kB
1532            increment_version          OPTIMIZER_STATE             20(v0)            2 kB
1533            increment_version          OPTIMIZER_STATE             21(v0)            2 kB
1534            increment_version          OPTIMIZER_STATE             21(v1)            2 kB
1535            create                     ???                         26(v0)            2 kB
1536            create                     ???                         27(v0)            2 kB
1537            destroy                    ???                         26(v0)            2 kB
1538            increment_version          ???                         27(v0)            2 kB
1539            destroy                    ???                         25(v1)          128 kB
1540            increment_version          PARAMETER                    1(v0)            2 kB
1541            increment_version          OPTIMIZER_STATE             22(v0)         1024 kB
1542            increment_version          OPTIMIZER_STATE             23(v0)         1024 kB
1543            increment_version          OPTIMIZER_STATE             23(v1)         1024 kB
1544            create                     ???                         28(v0)         1024 kB
1545            create                     ???                         29(v0)         1024 kB
1546            destroy                    ???                         28(v0)         1024 kB
1547            increment_version          ???                         29(v0)         1024 kB
1548            destroy                    ???                         27(v1)            2 kB
1549            increment_version          PARAMETER                    2(v0)         1024 kB
1550            destroy                    ???                         29(v1)         1024 kB
1551            destroy                    GRADIENT                    16(v0)          128 kB
1552            destroy                    GRADIENT                    17(v0)            2 kB
1553            destroy                    GRADIENT                    13(v0)         1024 kB""",
1554        )
1555
1556    def test_memory_timeline_no_id(self) -> None:
1557        # On CPU the default behavior is to simply forward to malloc. That
1558        # means that when we free `x` the allocator doesn't actually know how
1559        # many bytes are in the allocation, so there's no point in calling
1560        # `c10::reportMemoryUsageToProfiler`. Thus, to test that the memory
1561        # profiler handles this case correctly, we need to use CUDA, where we
1562        # always keep a record.
1563        x = torch.ones((1024,), device="cuda" if torch.cuda.is_available() else "cpu")
1564
1565        with profile() as prof:
1566            # We never see `x` used, so we don't know the storage is for a
1567            # Tensor, but we do still see the free event.
1568            del x
1569
1570            # For `torch.empty` we see the allocation and free, but not any
1571            # use, so this also cannot be identified as a Tensor.
1572            y = torch.empty((64,))
1573            del y
1574
1575            z = torch.empty((256,))
1576            z.view_as(z)  # Show `z` to the profiler
1577            del z
1578
1579        memory_profile = prof._memory_profile()
1580
1581        expected = [
1582            # x
1583            (_memory_profiler.Action.PREEXISTING, 4096),
1584            (_memory_profiler.Action.DESTROY, 4096),
1585            #
1586            # y
1587            (_memory_profiler.Action.CREATE, 256),
1588            (_memory_profiler.Action.DESTROY, 256),
1589            #
1590            # z
1591            (_memory_profiler.Action.CREATE, 1024),
1592            (_memory_profiler.Action.DESTROY, 1024),
1593        ]
1594
1595        actual = [(action, size) for _, action, _, size in memory_profile.timeline]
1596
1597        # See above.
1598        if not torch.cuda.is_available():
1599            expected = expected[2:]
1600            for event in expected:
1601                self.assertTrue(
1602                    event in actual, f"event: {event} was not found in actual."
1603                )
1604        else:
1605            self.assertEqual(
1606                actual,
1607                expected,
1608                f"expected does not match actual: {actual}",
1609            )
1610
1611
1612if __name__ == "__main__":
1613    run_tests()
1614