1# Owner(s): ["oncall: distributed"] 2import os 3import unittest 4 5import torch 6import torch.nn as nn 7from torch.distributed._tools import MemoryTracker 8from torch.testing._internal.common_cuda import TEST_CUDA 9from torch.testing._internal.common_utils import run_tests, TestCase 10 11 12class TestMemoryTracker(TestCase): 13 @unittest.skipIf(not TEST_CUDA, "no cuda") 14 def test_local_model(self): 15 """ 16 Minimal test case to check the memory tracker can collect the expected 17 memory stats at operator level, as well as can print the summary result 18 without crash. 19 """ 20 # Create a model with a hierarchy of modules 21 torch.manual_seed(0) 22 model = nn.Sequential( 23 nn.Sequential( 24 nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False), 25 nn.BatchNorm2d(64), 26 nn.ReLU(inplace=False), 27 nn.AdaptiveAvgPool2d(output_size=(1, 1)), 28 ), 29 nn.Flatten(start_dim=1), 30 nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)), 31 ).cuda() 32 33 # Run one iteration of forward and backward pass 34 tracker = MemoryTracker() 35 tracker.start_monitor(model) 36 37 x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda")) 38 # torch.LongTensor expects cpu device type, not cuda device type in 39 # constructor, so calling .cuda() outside constructor here. 40 target = torch.LongTensor([0, 1]).cuda() 41 criterion = nn.CrossEntropyLoss() 42 criterion(model(x), target).backward() 43 44 self.assertTrue(len(tracker._hooks) > 0) 45 46 tracker.stop() 47 48 self.assertTrue(len(tracker._hooks) == 0) 49 50 path = "memory.trace" 51 tracker.save_stats(path) 52 tracker.load(path) 53 tracker.summary() 54 if os.path.exists(path): 55 os.remove(path) 56 57 self.assertTrue(tracker._op_index > 0) 58 self.assertTrue(len(tracker._operator_names) > 0) 59 self.assertEqual(len(tracker.memories_allocated), tracker._op_index) 60 self.assertEqual(len(tracker.memories_active), tracker._op_index) 61 self.assertEqual(len(tracker.memories_reserved), tracker._op_index) 62 self.assertTrue(len(tracker._markers) == 2) 63 self.assertTrue(tracker._cur_module_name != "") 64 self.assertTrue(hasattr(tracker, "_num_cuda_retries")) 65 66 67if __name__ == "__main__": 68 run_tests() 69