xref: /aosp_15_r20/external/pytorch/test/distributed/_tools/test_memory_tracker.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2import os
3import unittest
4
5import torch
6import torch.nn as nn
7from torch.distributed._tools import MemoryTracker
8from torch.testing._internal.common_cuda import TEST_CUDA
9from torch.testing._internal.common_utils import run_tests, TestCase
10
11
class TestMemoryTracker(TestCase):
    """Smoke test for ``MemoryTracker`` on a small CUDA model."""

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_local_model(self):
        """
        Minimal test case to check the memory tracker can collect the expected
        memory stats at operator level, as well as can print the summary result
        without crash.

        The test runs one forward/backward pass under the tracker, checks that
        hooks exist while monitoring and are gone after ``stop()``, round-trips
        the collected stats through a trace file, and finally checks that the
        per-operator bookkeeping is populated and internally consistent.
        """
        # Function-scope import so this test does not depend on a change to
        # the module-level import block.
        import tempfile

        # Create a model with a hierarchy of modules
        torch.manual_seed(0)
        model = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=False),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ),
            nn.Flatten(start_dim=1),
            nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
        ).cuda()

        # Run one iteration of forward and backward pass
        tracker = MemoryTracker()
        tracker.start_monitor(model)

        x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda"))
        # torch.LongTensor expects cpu device type, not cuda device type in
        # constructor, so calling .cuda() outside constructor here.
        target = torch.LongTensor([0, 1]).cuda()
        criterion = nn.CrossEntropyLoss()
        criterion(model(x), target).backward()

        # Hooks must be present while monitoring and gone once it is stopped.
        self.assertGreater(len(tracker._hooks), 0)
        tracker.stop()
        self.assertEqual(len(tracker._hooks), 0)

        # Round-trip the stats through a file inside a private temporary
        # directory: the file is removed even if save/load/summary raises, and
        # nothing is written to the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "memory.trace")
            tracker.save_stats(path)
            tracker.load(path)
            tracker.summary()

        # After the reload, the tracker must report at least one operator and
        # exactly one memory sample of each kind per recorded operator index.
        self.assertGreater(tracker._op_index, 0)
        self.assertGreater(len(tracker._operator_names), 0)
        self.assertEqual(len(tracker.memories_allocated), tracker._op_index)
        self.assertEqual(len(tracker.memories_active), tracker._op_index)
        self.assertEqual(len(tracker.memories_reserved), tracker._op_index)
        self.assertEqual(len(tracker._markers), 2)
        self.assertNotEqual(tracker._cur_module_name, "")
        self.assertTrue(hasattr(tracker, "_num_cuda_retries"))
65
66
# Script entry point: when this file is executed directly, hand control to the
# run_tests() helper imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
69