1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport collections 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN, TestCase 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workertry: 11*da0073e9SAndroid Build Coastguard Worker import psutil 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker HAS_PSUTIL = True 14*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError: 15*da0073e9SAndroid Build Coastguard Worker HAS_PSUTIL = False 16*da0073e9SAndroid Build Coastguard Worker psutil = None 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerdevice = torch.device("cpu") 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Workerclass Network(torch.nn.Module): 23*da0073e9SAndroid Build Coastguard Worker maxp1 = torch.nn.MaxPool2d(1, 1) 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 26*da0073e9SAndroid Build Coastguard Worker return self.maxp1(x) 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") 30*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN") 31*da0073e9SAndroid Build Coastguard Workerclass TestOpenMP_ParallelFor(TestCase): 32*da0073e9SAndroid Build Coastguard Worker batch = 20 33*da0073e9SAndroid Build Coastguard Worker channels = 1 34*da0073e9SAndroid Build Coastguard Worker side_dim = 80 35*da0073e9SAndroid Build Coastguard Worker x = torch.randn([batch, channels, side_dim, side_dim], device=device) 36*da0073e9SAndroid Build Coastguard Worker model = Network() 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker def func(self, runs): 39*da0073e9SAndroid Build Coastguard Worker p = psutil.Process() 40*da0073e9SAndroid Build Coastguard Worker # warm up for 5 runs, then things should be stable for the last 5 41*da0073e9SAndroid Build Coastguard Worker last_rss = collections.deque(maxlen=5) 42*da0073e9SAndroid Build Coastguard Worker for n in range(10): 43*da0073e9SAndroid Build Coastguard Worker for i in range(runs): 44*da0073e9SAndroid Build Coastguard Worker self.model(self.x) 45*da0073e9SAndroid Build Coastguard Worker last_rss.append(p.memory_info().rss) 46*da0073e9SAndroid Build Coastguard Worker return last_rss 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker def func_rss(self, runs): 49*da0073e9SAndroid Build Coastguard Worker last_rss = list(self.func(runs)) 50*da0073e9SAndroid Build Coastguard Worker # Check that the sequence is not strictly increasing 51*da0073e9SAndroid Build Coastguard Worker is_increasing = True 52*da0073e9SAndroid Build Coastguard Worker for idx in range(len(last_rss)): 53*da0073e9SAndroid Build Coastguard Worker if idx == 0: 54*da0073e9SAndroid Build Coastguard Worker continue 55*da0073e9SAndroid Build Coastguard Worker is_increasing = is_increasing and (last_rss[idx] > last_rss[idx - 1]) 56*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 57*da0073e9SAndroid Build Coastguard Worker not is_increasing, msg=f"memory usage is increasing, {str(last_rss)}" 58*da0073e9SAndroid Build Coastguard Worker ) 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker def test_one_thread(self): 61*da0073e9SAndroid Build Coastguard Worker """Make sure there is no memory leak with one thread: issue gh-32284""" 62*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(1) 63*da0073e9SAndroid Build Coastguard Worker self.func_rss(300) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_n_threads(self): 66*da0073e9SAndroid Build Coastguard Worker """Make sure there is no memory leak with many threads""" 67*da0073e9SAndroid Build Coastguard Worker ncores = min(5, psutil.cpu_count(logical=False)) 68*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(ncores) 69*da0073e9SAndroid Build Coastguard Worker self.func_rss(300) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 73*da0073e9SAndroid Build Coastguard Worker run_tests() 74