# Owner(s): ["module: multiprocessing"]

import contextlib
import copy
import gc
import os
import sys
import time
import unittest
from sys import platform

import torch
import torch.cuda
import torch.multiprocessing as mp
import torch.utils.hooks
from torch.nn import Parameter
from torch.testing._internal.common_cuda import IS_JETSON
from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_WINDOWS,
    load_tests,
    NO_MULTIPROCESSING_SPAWN,
    run_tests,
    slowTest,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TEST_WITH_TORCHDYNAMO,
    TEST_WITH_TSAN,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# How many times the sharing tests are repeated to shake out races/leaks.
TEST_REPEATS = 30
# Shared-memory leak checks only make sense where /dev/shm exists (Linux).
HAS_SHM_FILES = os.path.isdir("/dev/shm")
# Upper bound for the polling loops below that wait on child processes.
MAX_WAITING_TIME_IN_SECONDS = 30

# CUDA IPC is unavailable on macOS/Windows/Jetson, and disabled on ROCm.
TEST_CUDA_IPC = (
    torch.cuda.is_available()
    and sys.platform != "darwin"
    and sys.platform != "win32"
    and not IS_JETSON
    and not TEST_WITH_ROCM
)  # https://github.com/pytorch/pytorch/issues/90940

TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1


class SubProcess(mp.Process):
    """Daemon child process that adds 3 in-place to a tensor it inherited
    from the parent (exercises fork-inheritance of shared-memory tensors)."""

    def __init__(self, tensor):
        super().__init__()
        self.tensor = tensor
        self.daemon = True

    def run(self):
        # In-place update must be visible to the parent via shared memory.
        self.tensor.add_(3)


def _test_cuda_ipc_deadlock_actor(queue, iterations):
    """Consumer side of the CUDA IPC deadlock regression test: polls the
    queue for `iterations` rounds, draining one item when available."""
    for i in range(iterations):
        if not queue.empty():
            queue.get()
        time.sleep(0.01)


def _test_cuda_ipc_deadlock_learner(queue, iterations):
    """Producer side of the CUDA IPC deadlock regression test: repeatedly
    puts a deep copy of a CUDA LSTM state_dict on the queue when not full."""
    net = torch.nn.LSTM(1, 1).cuda()
    for i in range(iterations):
        if not queue.full():
            queue.put(copy.deepcopy(net.state_dict()))
        time.sleep(0.01)


def simple_fill(queue, event):
    """Child target: take a list of tensors off the queue, fill the first
    one with 4 through the shared storage, then signal completion."""
    data = queue.get()
    data[0][:] = 4
    event.set()


def simple_pool_fill(tensor):
    """Pool worker: fill the (shared) input tensor with 4 in-place and
    return a new tensor holding 5s (checks both mutation and the result)."""
    tensor.fill_(4)
    return tensor.add(1)


def send_tensor(queue, event, device, dtype):
    """Child target: put the *same* ones-tensor on the queue twice (the two
    receipts should share storage), then hold it alive until the parent
    signals it is done checking."""
    t = torch.ones(5, 5, device=device, dtype=dtype)
    queue.put(t)
    queue.put(t)
    event.wait()


def send_and_delete_tensors(queue, event, device, dtype, count, size=5):
    """Child target: send `count` tensors (each filled with its index) and
    drop the local reference immediately after each put, then wait for the
    parent before exiting so IPC handles stay valid while in flight."""
    for i in range(count):
        t = torch.full([size], i, device=device, dtype=dtype)
        queue.put(t)
        del t
    event.wait()


def receive_and_send_sum(queue, out_queue, event, device, dtype, count, size=5):
    """Child target: accumulate `count` tensors from `queue` into a running
    sum, put the sum on `out_queue`, and wait for the parent's signal."""
    s = torch.full([size], 0, device=device, dtype=dtype)
    for i in range(count):
        t = queue.get()
        s += t
    out_queue.put(s)
    event.wait()


def receive_and_send(queue, out_queue, event, count):
    """Child target: forward `count` tensors from `queue` to `out_queue`,
    cloning each so the forwarded tensor has its own storage."""
    for i in range(count):
        t = queue.get()
        out_queue.put(t.clone())
    event.wait()


def sum_tensors(inq, outq):
    """Child target: on CUDA device 1 (multi-GPU only), report
    (sum, device, numel, storage size) for every received tensor."""
    with torch.cuda.device(1):
        tensors = inq.get()
        for tensor in tensors:
            outq.put(
                (
                    tensor.sum().item(),
                    tensor.get_device(),
                    tensor.numel(),
                    tensor.storage().size(),
                )
            )


def queue_get_exception(inqueue, outqueue):
    """Child target: attempt a CUDA op and report either the raised
    exception object or the string "no exception" back to the parent."""
    os.close(2)  # hide expected error message
    try:
        torch.zeros(5, 5).cuda()
    except Exception as e:
        outqueue.put(e)
    else:
        outqueue.put("no exception")


# Multiply by two in a separate stream
def cuda_multiply_two(queue, ready, done):
    """Child target: wait on the shared CUDA event, double the shared
    tensor in-place on a fresh stream, re-record the event, and signal."""
    ready.set()
    with torch.cuda.stream(torch.cuda.Stream()):
        cuda_event, tensor = queue.get()
        cuda_event.wait()
        tensor.mul_(2)
        cuda_event.record()
        done.set()
        del cuda_event


def requires_grad_variable_sharing(queue, ready):
    """Child target: echo back whether the received tensor kept its
    requires_grad flag across process boundaries."""
    var = queue.get()
    ready.set()
    queue.put(var.requires_grad)


def integer_parameter_serialization(iparam):
    """Child target: just touch the deserialized integer Parameter; the
    result is discarded — only checks the object survived pickling."""
    iparam + 1


def autograd_sharing(queue, ready, master_modified, device, is_parameter):
    """Child side of the autograd sharing test.

    After the master modifies the shared tensor, verify: the data change is
    visible here, grad and backward hooks did NOT cross the process
    boundary, and the received object has the expected exact type
    (Parameter vs. Tensor). Reports a single boolean verdict via `queue`,
    and writes back to the shared data / sets a grad for the master to check.
    """
    var = queue.get()
    ready.set()
    master_modified.wait()

    expected_var = torch.arange(1.0, 26, device=device).view(5, 5)
    expected_var[0, 0] = 1000
    is_ok = var.data.equal(expected_var)
    var.data[:] = torch.ones(5, 5, device=device)

    is_ok &= var.grad is None
    is_ok &= not var._backward_hooks
    if is_parameter:
        is_ok &= type(var) == Parameter
    else:
        is_ok &= type(var) == torch.Tensor
    var._grad = torch.ones(5, 5, device=device)

    queue.put(is_ok)


def mixed_type_producer(queue, event):
    """Child target: repeatedly send a float and a byte CUDA tensor pair,
    pausing on `event` each round until the consumer has taken both."""
    for _ in range(10):
        float_tensor = torch.ones(2, 2).float().cuda()
        byte_tensor = torch.zeros(2, 2).byte().cuda()

        queue.put(float_tensor)
        queue.put(byte_tensor)
        event.wait()
        event.clear()


def simple_autograd_function(a=1):
    """Run a tiny backward pass and return a**2; used to probe whether
    autograd works (or raises) inside forked/spawned workers."""
    torch.rand(3).requires_grad_(True).mean().backward()
    return a**2


@contextlib.contextmanager
def fs_sharing():
    """Context manager that temporarily switches torch.multiprocessing to
    the "file_system" sharing strategy, restoring the previous one on exit."""
    prev_strategy = mp.get_sharing_strategy()
    mp.set_sharing_strategy("file_system")
    try:
        yield
    finally:
        mp.set_sharing_strategy(prev_strategy)


class leak_checker:
    """Test-scoped context manager that checks for leaked /dev/shm files
    (and, historically, file descriptors) left behind by the PIDs it was
    told about via check_pid()."""

    def __init__(self, test_case):
        # Always watch our own PID; children are added via check_pid().
        self.checked_pids = [os.getpid()]
        self.test_case = test_case

    def __enter__(self):
        # Snapshot of the next 10 free fds, kept for the (currently
        # disabled) fd-leak check in __exit__.
        self.next_fds = self._get_next_fds(10)
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
        # Only run leak checks when the test body did not raise.
        if args[0] is None:
            # Check that the 10th available file-descriptor at the end of the
            # test is no more than 4 higher than the 10th available at the
            # start. This attempts to catch file descriptor leaks, but allows
            # one-off initialization that may use up a file descriptor
            # TODO: Disabled because this check is too flaky
            # available_fds = self._get_next_fds(10)
            # self.test_case.assertLessEqual(
            #     available_fds[-1] - self.next_fds[-1], 5)
            self.test_case.assertFalse(self.has_shm_files())
        return False

    def check_pid(self, pid):
        """Register a child PID whose torch_<pid>* shm files we must watch."""
        self.checked_pids.append(pid)

    def _get_next_fds(self, n=1):
        # dup uses the lowest-numbered unused descriptor for the new descriptor
        fds = [os.dup(0) for i in range(n)]
        for fd in fds:
            os.close(fd)
        return fds

    def has_shm_files(self, wait=True):
        """Return True if any watched shm files still exist.

        Under the file_system strategy files are cleaned up asynchronously,
        so when `wait` is True this polls for up to
        MAX_WAITING_TIME_IN_SECONDS before concluding there is a leak.
        """
        if not HAS_SHM_FILES:
            return False

        result = self._has_shm_files()
        if not result or mp.get_sharing_strategy() != "file_system" or not wait:
            return result

        total_waiting_time = 0
        waiting_time = 0.5

        while total_waiting_time <= MAX_WAITING_TIME_IN_SECONDS and result:
            time.sleep(waiting_time)
            total_waiting_time += waiting_time
            result = self._has_shm_files()

        return result

    def _has_shm_files(self):
        # gc first so unreferenced tensors release their shm files before
        # we scan the directory.
        gc.collect()
        names = ["torch_" + str(pid) for pid in self.checked_pids]
        for filename in os.listdir("/dev/shm"):
            for name in names:
                if filename.startswith(name):
                    return True
        return False
Build Coastguard Worker@unittest.skipIf( 275*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 276*da0073e9SAndroid Build Coastguard Worker "TSAN is not fork-safe since we're forking in a multi-threaded environment", 277*da0073e9SAndroid Build Coastguard Worker) 278*da0073e9SAndroid Build Coastguard Workerclass TestMultiprocessing(TestCase): 279*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 280*da0073e9SAndroid Build Coastguard Worker # This will keep tests isolated from each-other 281*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 282*da0073e9SAndroid Build Coastguard Worker torch.cuda.ipc_collect() 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker def _test_sharing(self, ctx=mp, device="cpu", dtype=torch.float, repeat=1): 285*da0073e9SAndroid Build Coastguard Worker def test_fill(): 286*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(5, 5).to(device, dtype) 287*da0073e9SAndroid Build Coastguard Worker q = ctx.Queue() 288*da0073e9SAndroid Build Coastguard Worker e = ctx.Event() 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker data = [x, x[:, 1]] 291*da0073e9SAndroid Build Coastguard Worker q.put(data) 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker p = ctx.Process(target=simple_fill, args=(q, e)) 294*da0073e9SAndroid Build Coastguard Worker p.daemon = True 295*da0073e9SAndroid Build Coastguard Worker lc.check_pid(p.pid) 296*da0073e9SAndroid Build Coastguard Worker p.start() 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker total_waiting_time = 0 299*da0073e9SAndroid Build Coastguard Worker waiting_time = 0.5 300*da0073e9SAndroid Build Coastguard Worker is_set = False 301*da0073e9SAndroid Build Coastguard Worker # Once the child process is done, it will set the event to notify the 302*da0073e9SAndroid Build Coastguard Worker # parent accordingly 
303*da0073e9SAndroid Build Coastguard Worker while total_waiting_time <= MAX_WAITING_TIME_IN_SECONDS and not is_set: 304*da0073e9SAndroid Build Coastguard Worker time.sleep(waiting_time) 305*da0073e9SAndroid Build Coastguard Worker total_waiting_time += waiting_time 306*da0073e9SAndroid Build Coastguard Worker is_set = e.is_set() 307*da0073e9SAndroid Build Coastguard Worker 308*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_set) 309*da0073e9SAndroid Build Coastguard Worker if device != "meta": 310*da0073e9SAndroid Build Coastguard Worker self.assertTrue(data[0].eq(4).all()) 311*da0073e9SAndroid Build Coastguard Worker self.assertTrue(data[1].eq(4).all()) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker p.join(100) 314*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker def test_receive(): 317*da0073e9SAndroid Build Coastguard Worker q = ctx.Queue() 318*da0073e9SAndroid Build Coastguard Worker e = ctx.Event() 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker p = ctx.Process(target=send_tensor, args=(q, e, device, dtype)) 321*da0073e9SAndroid Build Coastguard Worker p.daemon = True 322*da0073e9SAndroid Build Coastguard Worker lc.check_pid(p.pid) 323*da0073e9SAndroid Build Coastguard Worker p.start() 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker t1 = q.get() 326*da0073e9SAndroid Build Coastguard Worker t2 = q.get() 327*da0073e9SAndroid Build Coastguard Worker if device == "meta": 328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.size(), t2.size()) 329*da0073e9SAndroid Build Coastguard Worker else: 330*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t1.eq(1).all()) 331*da0073e9SAndroid Build Coastguard Worker s1 = t1.storage() 332*da0073e9SAndroid Build Coastguard Worker s2 = t2.storage() 
333*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(s1), type(s2)) 334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1.data_ptr(), s1.data_ptr()) 335*da0073e9SAndroid Build Coastguard Worker if device == "meta": 336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1.size(), s2.size()) 337*da0073e9SAndroid Build Coastguard Worker else: 338*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1, s2) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker # We need to delete this tensors to allow producer (child process) 341*da0073e9SAndroid Build Coastguard Worker # collect them properly 342*da0073e9SAndroid Build Coastguard Worker del t1, t2 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker # Mark the event as done and join the process 345*da0073e9SAndroid Build Coastguard Worker e.set() 346*da0073e9SAndroid Build Coastguard Worker p.join(100) 347*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker with leak_checker(self) as lc: 350*da0073e9SAndroid Build Coastguard Worker for _ in range(repeat): 351*da0073e9SAndroid Build Coastguard Worker test_fill() 352*da0073e9SAndroid Build Coastguard Worker test_receive() 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker def _test_preserve_sharing(self, ctx=mp, repeat=1): 355*da0073e9SAndroid Build Coastguard Worker def do_test(): 356*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 357*da0073e9SAndroid Build Coastguard Worker data = [x.storage(), x, x[2], x[:, 1]] 358*da0073e9SAndroid Build Coastguard Worker q = ctx.Queue() 359*da0073e9SAndroid Build Coastguard Worker q.put(data) 360*da0073e9SAndroid Build Coastguard Worker new_data = q.get(timeout=1) 361*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_data, data, atol=0, rtol=0) 
362*da0073e9SAndroid Build Coastguard Worker storage_cdata = data[0]._cdata 363*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_data[0]._cdata, storage_cdata) 364*da0073e9SAndroid Build Coastguard Worker for t in new_data[1:]: 365*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.storage()._cdata, storage_cdata) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker with leak_checker(self): 368*da0073e9SAndroid Build Coastguard Worker for _ in range(repeat): 369*da0073e9SAndroid Build Coastguard Worker do_test() 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker def _test_pool(self, ctx=mp, repeat=1): 372*da0073e9SAndroid Build Coastguard Worker def do_test(): 373*da0073e9SAndroid Build Coastguard Worker p = ctx.Pool(2) 374*da0073e9SAndroid Build Coastguard Worker for proc in p._pool: 375*da0073e9SAndroid Build Coastguard Worker lc.check_pid(proc.pid) 376*da0073e9SAndroid Build Coastguard Worker 377*da0073e9SAndroid Build Coastguard Worker buffers = [torch.zeros(2, 2) for i in range(4)] 378*da0073e9SAndroid Build Coastguard Worker results = p.map(simple_pool_fill, buffers, 1) 379*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(results), len(buffers)) 380*da0073e9SAndroid Build Coastguard Worker for r in results: 381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, torch.ones(2, 2) * 5, atol=0, rtol=0) 382*da0073e9SAndroid Build Coastguard Worker for b in buffers: 383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b, torch.ones(2, 2) * 4, atol=0, rtol=0) 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker p.close() 386*da0073e9SAndroid Build Coastguard Worker p.join() 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker with leak_checker(self) as lc: 389*da0073e9SAndroid Build Coastguard Worker for _ in range(repeat): 390*da0073e9SAndroid Build Coastguard 
Worker do_test() 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 393*da0073e9SAndroid Build Coastguard Worker platform == "darwin", "file descriptor strategy is not supported on macOS" 394*da0073e9SAndroid Build Coastguard Worker ) 395*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 396*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ASAN, 397*da0073e9SAndroid Build Coastguard Worker "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326", 398*da0073e9SAndroid Build Coastguard Worker ) 399*da0073e9SAndroid Build Coastguard Worker def test_fd_sharing(self): 400*da0073e9SAndroid Build Coastguard Worker self._test_sharing(repeat=TEST_REPEATS) 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 403*da0073e9SAndroid Build Coastguard Worker platform == "darwin", "file descriptor strategy is not supported on macOS" 404*da0073e9SAndroid Build Coastguard Worker ) 405*da0073e9SAndroid Build Coastguard Worker def test_fd_preserve_sharing(self): 406*da0073e9SAndroid Build Coastguard Worker self._test_preserve_sharing(repeat=TEST_REPEATS) 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 409*da0073e9SAndroid Build Coastguard Worker platform == "darwin", "file descriptor strategy is not supported on macOS" 410*da0073e9SAndroid Build Coastguard Worker ) 411*da0073e9SAndroid Build Coastguard Worker def test_fd_pool(self): 412*da0073e9SAndroid Build Coastguard Worker self._test_pool(repeat=TEST_REPEATS) 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 415*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ASAN, 416*da0073e9SAndroid Build Coastguard Worker "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326", 417*da0073e9SAndroid Build Coastguard Worker ) 418*da0073e9SAndroid Build Coastguard 
Worker @unittest.skipIf( 419*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, 420*da0073e9SAndroid Build Coastguard Worker "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467", 421*da0073e9SAndroid Build Coastguard Worker ) 422*da0073e9SAndroid Build Coastguard Worker def test_fs_sharing(self): 423*da0073e9SAndroid Build Coastguard Worker with fs_sharing(): 424*da0073e9SAndroid Build Coastguard Worker # The test works but is very slow on MacOS, see https://github.com/pytorch/pytorch/pull/93183, 425*da0073e9SAndroid Build Coastguard Worker # so run it only once there. The delay is in waiting for the child process to terminate (join) 426*da0073e9SAndroid Build Coastguard Worker repeat = 1 if IS_MACOS else TEST_REPEATS 427*da0073e9SAndroid Build Coastguard Worker self._test_sharing(repeat=repeat) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 430*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, 431*da0073e9SAndroid Build Coastguard Worker "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467", 432*da0073e9SAndroid Build Coastguard Worker ) 433*da0073e9SAndroid Build Coastguard Worker def test_fs_preserve_sharing(self): 434*da0073e9SAndroid Build Coastguard Worker with fs_sharing(): 435*da0073e9SAndroid Build Coastguard Worker self._test_preserve_sharing(repeat=TEST_REPEATS) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 438*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, 439*da0073e9SAndroid Build Coastguard Worker "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467", 440*da0073e9SAndroid Build Coastguard Worker ) 441*da0073e9SAndroid Build Coastguard Worker def test_fs_pool(self): 442*da0073e9SAndroid Build Coastguard Worker with fs_sharing(): 
443*da0073e9SAndroid Build Coastguard Worker self._test_pool(repeat=TEST_REPEATS) 444*da0073e9SAndroid Build Coastguard Worker 445*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_SHM_FILES, "don't not how to check if shm files exist") 446*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 447*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, 448*da0073e9SAndroid Build Coastguard Worker "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467", 449*da0073e9SAndroid Build Coastguard Worker ) 450*da0073e9SAndroid Build Coastguard Worker def test_fs(self): 451*da0073e9SAndroid Build Coastguard Worker def queue_put(): 452*da0073e9SAndroid Build Coastguard Worker x = torch.DoubleStorage(4) 453*da0073e9SAndroid Build Coastguard Worker q = mp.Queue() 454*da0073e9SAndroid Build Coastguard Worker self.assertFalse(lc.has_shm_files()) 455*da0073e9SAndroid Build Coastguard Worker q.put(x) 456*da0073e9SAndroid Build Coastguard Worker time.sleep(0.05) # queue serializes asynchronously 457*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lc.has_shm_files(wait=False)) 458*da0073e9SAndroid Build Coastguard Worker q.get() 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker with fs_sharing(), leak_checker(self) as lc: 461*da0073e9SAndroid Build Coastguard Worker for _ in range(TEST_REPEATS): 462*da0073e9SAndroid Build Coastguard Worker queue_put() 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Worker def test_inherit_tensor(self): 465*da0073e9SAndroid Build Coastguard Worker t = torch.zeros(5, 5) 466*da0073e9SAndroid Build Coastguard Worker p = SubProcess(t.share_memory_()) 467*da0073e9SAndroid Build Coastguard Worker p.start() 468*da0073e9SAndroid Build Coastguard Worker p.join(2) 469*da0073e9SAndroid Build Coastguard Worker if p.exitcode is None: 470*da0073e9SAndroid Build Coastguard Worker print("test_inherit_tensor: 
SubProcess too slow") 471*da0073e9SAndroid Build Coastguard Worker else: 472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.ones(5, 5) * 3, atol=0, rtol=0) 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "Test needs to use fork multiprocessing") 475*da0073e9SAndroid Build Coastguard Worker def test_autograd_errors(self): 476*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("fork") 477*da0073e9SAndroid Build Coastguard Worker simple_autograd_function() 478*da0073e9SAndroid Build Coastguard Worker # Autograd only uses thread when GPUs are involved 479*da0073e9SAndroid Build Coastguard Worker if ( 480*da0073e9SAndroid Build Coastguard Worker torch.cuda.is_available() 481*da0073e9SAndroid Build Coastguard Worker or torch.backends.mps.is_available() 482*da0073e9SAndroid Build Coastguard Worker or torch.xpu.is_available() 483*da0073e9SAndroid Build Coastguard Worker ): 484*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Unable to handle autograd"): 485*da0073e9SAndroid Build Coastguard Worker with ctx.Pool(3) as pool: 486*da0073e9SAndroid Build Coastguard Worker pool.map(simple_autograd_function, [1, 2, 3]) 487*da0073e9SAndroid Build Coastguard Worker else: 488*da0073e9SAndroid Build Coastguard Worker with ctx.Pool(3) as pool: 489*da0073e9SAndroid Build Coastguard Worker pool.map(simple_autograd_function, [1, 2, 3]) 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 492*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, "Test needs to use spawn multiprocessing" 493*da0073e9SAndroid Build Coastguard Worker ) 494*da0073e9SAndroid Build Coastguard Worker def test_autograd_fine_with_spawn(self): 495*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 496*da0073e9SAndroid Build Coastguard Worker simple_autograd_function() 497*da0073e9SAndroid 
Build Coastguard Worker with ctx.Pool(3) as pool: 498*da0073e9SAndroid Build Coastguard Worker pool.map(simple_autograd_function, [1, 2, 3]) 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 501*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 502*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 503*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 504*da0073e9SAndroid Build Coastguard Worker ) 505*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 506*da0073e9SAndroid Build Coastguard Worker def test_cuda_simple(self): 507*da0073e9SAndroid Build Coastguard Worker torch.cuda.FloatTensor([1]) # initialize CUDA outside of leak checker 508*da0073e9SAndroid Build Coastguard Worker self._test_sharing(mp.get_context("spawn"), "cuda", torch.float) 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 511*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 512*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 513*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 514*da0073e9SAndroid Build Coastguard Worker ) 515*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 516*da0073e9SAndroid Build Coastguard Worker def test_cuda_memory_allocation(self): 517*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 518*da0073e9SAndroid Build Coastguard Worker q = ctx.Queue() 519*da0073e9SAndroid Build Coastguard Worker e = ctx.Event() 520*da0073e9SAndroid Build Coastguard Worker p = ctx.Process( 521*da0073e9SAndroid Build Coastguard Worker target=send_and_delete_tensors, args=(q, e, "cuda", torch.int, 5) 522*da0073e9SAndroid Build Coastguard Worker ) 523*da0073e9SAndroid Build 
Coastguard Worker p.start() 524*da0073e9SAndroid Build Coastguard Worker t = [] 525*da0073e9SAndroid Build Coastguard Worker for _ in range(5): 526*da0073e9SAndroid Build Coastguard Worker t.append(q.get()) 527*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0], torch.full([5], 0, dtype=torch.int32)) 528*da0073e9SAndroid Build Coastguard Worker del t 529*da0073e9SAndroid Build Coastguard Worker e.set() 530*da0073e9SAndroid Build Coastguard Worker p.join(1) 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 533*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 534*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 535*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 536*da0073e9SAndroid Build Coastguard Worker ) 537*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 538*da0073e9SAndroid Build Coastguard Worker def test_cuda_ipc_deadlock(self): 539*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 540*da0073e9SAndroid Build Coastguard Worker queue = ctx.Queue(1) 541*da0073e9SAndroid Build Coastguard Worker processes = dict( 542*da0073e9SAndroid Build Coastguard Worker a=ctx.Process(target=_test_cuda_ipc_deadlock_actor, args=(queue, 100)), 543*da0073e9SAndroid Build Coastguard Worker l=ctx.Process(target=_test_cuda_ipc_deadlock_learner, args=(queue, 100)), 544*da0073e9SAndroid Build Coastguard Worker ) 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker for p in processes.values(): 547*da0073e9SAndroid Build Coastguard Worker p.start() 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker for p in processes.values(): 550*da0073e9SAndroid Build Coastguard Worker p.join(10) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker for p in 
processes.values(): 553*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker @slowTest 556*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 557*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 558*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 559*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 560*da0073e9SAndroid Build Coastguard Worker ) 561*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 562*da0073e9SAndroid Build Coastguard Worker def test_cuda_send_many(self, name=None, size=5, count=100000): 563*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 564*da0073e9SAndroid Build Coastguard Worker q1 = ctx.Queue() 565*da0073e9SAndroid Build Coastguard Worker q2 = ctx.Queue() 566*da0073e9SAndroid Build Coastguard Worker q3 = ctx.Queue() 567*da0073e9SAndroid Build Coastguard Worker e1 = ctx.Event() 568*da0073e9SAndroid Build Coastguard Worker e2 = ctx.Event() 569*da0073e9SAndroid Build Coastguard Worker e3 = ctx.Event() 570*da0073e9SAndroid Build Coastguard Worker p1 = ctx.Process( 571*da0073e9SAndroid Build Coastguard Worker target=send_and_delete_tensors, 572*da0073e9SAndroid Build Coastguard Worker args=(q1, e1, "cuda", torch.long, count, size), 573*da0073e9SAndroid Build Coastguard Worker ) 574*da0073e9SAndroid Build Coastguard Worker p2 = ctx.Process(target=receive_and_send, args=(q1, q2, e2, count)) 575*da0073e9SAndroid Build Coastguard Worker p3 = ctx.Process( 576*da0073e9SAndroid Build Coastguard Worker target=receive_and_send_sum, 577*da0073e9SAndroid Build Coastguard Worker args=(q2, q3, e3, "cuda", torch.long, count, size), 578*da0073e9SAndroid Build Coastguard Worker ) 579*da0073e9SAndroid Build Coastguard Worker p1.start() 580*da0073e9SAndroid Build Coastguard Worker 
p2.start() 581*da0073e9SAndroid Build Coastguard Worker p3.start() 582*da0073e9SAndroid Build Coastguard Worker result = q3.get() 583*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result[0], int(count * (count - 1) / 2)) 584*da0073e9SAndroid Build Coastguard Worker del result 585*da0073e9SAndroid Build Coastguard Worker e1.set() 586*da0073e9SAndroid Build Coastguard Worker e2.set() 587*da0073e9SAndroid Build Coastguard Worker e3.set() 588*da0073e9SAndroid Build Coastguard Worker p1.join(1) 589*da0073e9SAndroid Build Coastguard Worker p2.join(1) 590*da0073e9SAndroid Build Coastguard Worker p3.join(1) 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 593*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 594*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 595*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 596*da0073e9SAndroid Build Coastguard Worker ) 597*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 598*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU") 599*da0073e9SAndroid Build Coastguard Worker def test_cuda_small_tensors(self): 600*da0073e9SAndroid Build Coastguard Worker # Check multiple small tensors which will likely use the same 601*da0073e9SAndroid Build Coastguard Worker # underlying cached allocation 602*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 603*da0073e9SAndroid Build Coastguard Worker tensors = [] 604*da0073e9SAndroid Build Coastguard Worker for i in range(5): 605*da0073e9SAndroid Build Coastguard Worker device = i % 2 606*da0073e9SAndroid Build Coastguard Worker tensors += [torch.arange(i * 5.0, (i + 1) * 5).cuda(device)] 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Worker inq = ctx.Queue() 609*da0073e9SAndroid Build 
Coastguard Worker outq = ctx.Queue() 610*da0073e9SAndroid Build Coastguard Worker inq.put(tensors) 611*da0073e9SAndroid Build Coastguard Worker p = ctx.Process(target=sum_tensors, args=(inq, outq)) 612*da0073e9SAndroid Build Coastguard Worker p.start() 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Worker results = [] 615*da0073e9SAndroid Build Coastguard Worker for _ in range(5): 616*da0073e9SAndroid Build Coastguard Worker results.append(outq.get()) 617*da0073e9SAndroid Build Coastguard Worker p.join() 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker for i, _tensor in enumerate(tensors): 620*da0073e9SAndroid Build Coastguard Worker v, device, tensor_size, storage_size = results[i] 621*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v, torch.arange(i * 5.0, (i + 1) * 5).sum()) 622*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device, i % 2) 623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_size, 5) 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker # You might think this should be the case, but it's not! After 626*da0073e9SAndroid Build Coastguard Worker # data from the CUDA caching allocator goes through IPC, the 627*da0073e9SAndroid Build Coastguard Worker # size of the storage is the size of the *cached cudaMalloc for 628*da0073e9SAndroid Build Coastguard Worker # the entire memory block* of the storage, not just the storage. 
629*da0073e9SAndroid Build Coastguard Worker # See Note [CUDA IPC and the caching allocator] for more info 630*da0073e9SAndroid Build Coastguard Worker # 631*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(storage_size, 5) 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker # Collect current process (producer) files, make sure nothing holds 634*da0073e9SAndroid Build Coastguard Worker # ref to the sent tensors 635*da0073e9SAndroid Build Coastguard Worker del _tensor 636*da0073e9SAndroid Build Coastguard Worker del tensors 637*da0073e9SAndroid Build Coastguard Worker 638*da0073e9SAndroid Build Coastguard Worker # We need to collect, as CUDA MP implementation holds one shared 639*da0073e9SAndroid Build Coastguard Worker # memory 'file' for performance reason 640*da0073e9SAndroid Build Coastguard Worker torch.cuda.ipc_collect() 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)") 643*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 644*da0073e9SAndroid Build Coastguard Worker def test_cuda_bad_call(self): 645*da0073e9SAndroid Build Coastguard Worker # Initialize CUDA 646*da0073e9SAndroid Build Coastguard Worker t = torch.zeros(5, 5).cuda().cpu() 647*da0073e9SAndroid Build Coastguard Worker inq = mp.Queue() 648*da0073e9SAndroid Build Coastguard Worker outq = mp.Queue() 649*da0073e9SAndroid Build Coastguard Worker p = mp.Process(target=queue_get_exception, args=(inq, outq)) 650*da0073e9SAndroid Build Coastguard Worker p.start() 651*da0073e9SAndroid Build Coastguard Worker inq.put(t) 652*da0073e9SAndroid Build Coastguard Worker p.join() 653*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(outq.get(), RuntimeError) 654*da0073e9SAndroid Build Coastguard Worker 655*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, 
"not applicable to Windows (only fails with fork)") 656*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 657*da0073e9SAndroid Build Coastguard Worker def test_wrong_cuda_fork(self): 658*da0073e9SAndroid Build Coastguard Worker stderr = TestCase.runWithPytorchAPIUsageStderr( 659*da0073e9SAndroid Build Coastguard Worker """\ 660*da0073e9SAndroid Build Coastguard Workerimport torch 661*da0073e9SAndroid Build Coastguard Workerfrom torch.multiprocessing import Process 662*da0073e9SAndroid Build Coastguard Workerdef run(rank): 663*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_device(rank) 664*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 665*da0073e9SAndroid Build Coastguard Worker size = 2 666*da0073e9SAndroid Build Coastguard Worker processes = [] 667*da0073e9SAndroid Build Coastguard Worker for rank in range(size): 668*da0073e9SAndroid Build Coastguard Worker # it would work fine without the line below 669*da0073e9SAndroid Build Coastguard Worker x = torch.rand(20, 2).cuda() 670*da0073e9SAndroid Build Coastguard Worker p = Process(target=run, args=(rank,)) 671*da0073e9SAndroid Build Coastguard Worker p.start() 672*da0073e9SAndroid Build Coastguard Worker processes.append(p) 673*da0073e9SAndroid Build Coastguard Worker for p in processes: 674*da0073e9SAndroid Build Coastguard Worker p.join() 675*da0073e9SAndroid Build Coastguard Worker""" 676*da0073e9SAndroid Build Coastguard Worker ) 677*da0073e9SAndroid Build Coastguard Worker self.assertRegex(stderr, "Cannot re-initialize CUDA in forked subprocess.") 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 680*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 681*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 682*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 683*da0073e9SAndroid 
Build Coastguard Worker ) 684*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 685*da0073e9SAndroid Build Coastguard Worker def test_event(self): 686*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 687*da0073e9SAndroid Build Coastguard Worker queue = ctx.Queue() 688*da0073e9SAndroid Build Coastguard Worker ready = ctx.Event() 689*da0073e9SAndroid Build Coastguard Worker done = ctx.Event() 690*da0073e9SAndroid Build Coastguard Worker p = ctx.Process(target=cuda_multiply_two, args=(queue, ready, done)) 691*da0073e9SAndroid Build Coastguard Worker p.start() 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker ready.wait() 694*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(torch.cuda.Stream()): 695*da0073e9SAndroid Build Coastguard Worker tensor = torch.cuda.FloatTensor([1, 1, 1, 1]) 696*da0073e9SAndroid Build Coastguard Worker # Use a sleep kernel to test events. Without the event, the 697*da0073e9SAndroid Build Coastguard Worker # multiply happens before the add. 
698*da0073e9SAndroid Build Coastguard Worker event = torch.cuda.Event(interprocess=True) 699*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(20000000) # about 30 ms 700*da0073e9SAndroid Build Coastguard Worker tensor.add_(1) 701*da0073e9SAndroid Build Coastguard Worker event.record() 702*da0073e9SAndroid Build Coastguard Worker queue.put((event, tensor)) 703*da0073e9SAndroid Build Coastguard Worker done.wait() # must wait until subprocess records event 704*da0073e9SAndroid Build Coastguard Worker event.synchronize() 705*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(tensor), [4, 4, 4, 4]) 706*da0073e9SAndroid Build Coastguard Worker p.join() 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker @staticmethod 709*da0073e9SAndroid Build Coastguard Worker def _test_event_multiprocess_child(event, p2c, c2p): 710*da0073e9SAndroid Build Coastguard Worker c2p.put(0) # notify parent child is ready 711*da0073e9SAndroid Build Coastguard Worker p2c.get() # wait for record in parent 712*da0073e9SAndroid Build Coastguard Worker event.synchronize() 713*da0073e9SAndroid Build Coastguard Worker c2p.put(1) # notify parent synchronization is done 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 716*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 717*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 718*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 719*da0073e9SAndroid Build Coastguard Worker ) 720*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 721*da0073e9SAndroid Build Coastguard Worker def test_event_multiprocess(self): 722*da0073e9SAndroid Build Coastguard Worker event = torch.cuda.Event(enable_timing=False, interprocess=True) 723*da0073e9SAndroid Build Coastguard Worker self.assertTrue(event.query()) 
724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 726*da0073e9SAndroid Build Coastguard Worker p2c = ctx.SimpleQueue() 727*da0073e9SAndroid Build Coastguard Worker c2p = ctx.SimpleQueue() 728*da0073e9SAndroid Build Coastguard Worker p = ctx.Process( 729*da0073e9SAndroid Build Coastguard Worker target=TestMultiprocessing._test_event_multiprocess_child, 730*da0073e9SAndroid Build Coastguard Worker args=(event, p2c, c2p), 731*da0073e9SAndroid Build Coastguard Worker ) 732*da0073e9SAndroid Build Coastguard Worker p.start() 733*da0073e9SAndroid Build Coastguard Worker 734*da0073e9SAndroid Build Coastguard Worker c2p.get() # wait for until child process is ready 735*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(50000000) # spin for about 50 ms 736*da0073e9SAndroid Build Coastguard Worker event.record() 737*da0073e9SAndroid Build Coastguard Worker p2c.put(0) # notify child event is recorded 738*da0073e9SAndroid Build Coastguard Worker 739*da0073e9SAndroid Build Coastguard Worker self.assertFalse(event.query()) 740*da0073e9SAndroid Build Coastguard Worker c2p.get() # wait for synchronization in child 741*da0073e9SAndroid Build Coastguard Worker self.assertTrue(event.query()) 742*da0073e9SAndroid Build Coastguard Worker p.join() 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 745*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 746*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 747*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 748*da0073e9SAndroid Build Coastguard Worker ) 749*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 750*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU") 751*da0073e9SAndroid Build Coastguard Worker def 
test_event_handle_multi_gpu(self): 752*da0073e9SAndroid Build Coastguard Worker d0 = torch.device("cuda:0") 753*da0073e9SAndroid Build Coastguard Worker d1 = torch.device("cuda:1") 754*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 755*da0073e9SAndroid Build Coastguard Worker e0 = torch.cuda.Event(enable_timing=False, interprocess=True) 756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 758*da0073e9SAndroid Build Coastguard Worker # create handle on different device from un-recorded event 759*da0073e9SAndroid Build Coastguard Worker e0.ipc_handle() 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d0): 762*da0073e9SAndroid Build Coastguard Worker e1 = torch.cuda.Event(enable_timing=False, interprocess=True) 763*da0073e9SAndroid Build Coastguard Worker stream = torch.cuda.Stream() 764*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(50000000) # spin for about 50 ms 765*da0073e9SAndroid Build Coastguard Worker e1.record(stream) 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(d1): 768*da0073e9SAndroid Build Coastguard Worker # create handle on different device from recorded event 769*da0073e9SAndroid Build Coastguard Worker e1.ipc_handle() 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Worker @staticmethod 772*da0073e9SAndroid Build Coastguard Worker def _test_event_handle_importer_consumer(handle, p2c, c2p): 773*da0073e9SAndroid Build Coastguard Worker e1 = torch.cuda.Event.from_ipc_handle(0, handle) 774*da0073e9SAndroid Build Coastguard Worker c2p.put(0) # notify parent child is ready 775*da0073e9SAndroid Build Coastguard Worker p2c.get() # wait for record in parent 776*da0073e9SAndroid Build Coastguard Worker e1.synchronize() 777*da0073e9SAndroid Build Coastguard Worker c2p.put(1) # notify synchronization 
is done in child 778*da0073e9SAndroid Build Coastguard Worker p2c.get() # wait for parent to finish before destructing child event 779*da0073e9SAndroid Build Coastguard Worker 780*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 781*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 782*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 783*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 784*da0073e9SAndroid Build Coastguard Worker ) 785*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 786*da0073e9SAndroid Build Coastguard Worker def test_event_handle_importer(self): 787*da0073e9SAndroid Build Coastguard Worker e0 = torch.cuda.Event(enable_timing=False, interprocess=True) 788*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 791*da0073e9SAndroid Build Coastguard Worker p2c = ctx.SimpleQueue() 792*da0073e9SAndroid Build Coastguard Worker c2p = ctx.SimpleQueue() 793*da0073e9SAndroid Build Coastguard Worker p = ctx.Process( 794*da0073e9SAndroid Build Coastguard Worker target=TestMultiprocessing._test_event_handle_importer_consumer, 795*da0073e9SAndroid Build Coastguard Worker args=(e0.ipc_handle(), p2c, c2p), 796*da0073e9SAndroid Build Coastguard Worker ) 797*da0073e9SAndroid Build Coastguard Worker p.start() 798*da0073e9SAndroid Build Coastguard Worker 799*da0073e9SAndroid Build Coastguard Worker c2p.get() # wait for child to become ready 800*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(50000000) # spin for about 50 ms 801*da0073e9SAndroid Build Coastguard Worker e0.record() 802*da0073e9SAndroid Build Coastguard Worker p2c.put(0) # notify child event is recorded 803*da0073e9SAndroid Build Coastguard Worker 804*da0073e9SAndroid Build Coastguard Worker 
self.assertFalse(e0.query()) 805*da0073e9SAndroid Build Coastguard Worker c2p.get() # wait for synchronization in child 806*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 807*da0073e9SAndroid Build Coastguard Worker p2c.put(1) # notify child that parent is done 808*da0073e9SAndroid Build Coastguard Worker p.join() 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker @staticmethod 811*da0073e9SAndroid Build Coastguard Worker def _test_event_handle_exporter_consumer(handle, p2c, c2p): 812*da0073e9SAndroid Build Coastguard Worker stream = torch.cuda.Stream() 813*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(stream): 814*da0073e9SAndroid Build Coastguard Worker e1 = torch.cuda.Event.from_ipc_handle(torch.cuda.current_device(), handle) 815*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(50000000) # spin for about 50 ms 816*da0073e9SAndroid Build Coastguard Worker e1.record() 817*da0073e9SAndroid Build Coastguard Worker c2p.put(0) 818*da0073e9SAndroid Build Coastguard Worker # wait for parent process finished synchronization before 819*da0073e9SAndroid Build Coastguard Worker # destructing e1 820*da0073e9SAndroid Build Coastguard Worker p2c.get() 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 823*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 824*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 825*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 826*da0073e9SAndroid Build Coastguard Worker ) 827*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 828*da0073e9SAndroid Build Coastguard Worker def test_event_handle_exporter(self): 829*da0073e9SAndroid Build Coastguard Worker e0 = torch.cuda.Event(enable_timing=False, interprocess=True) 830*da0073e9SAndroid Build Coastguard Worker 
831*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 832*da0073e9SAndroid Build Coastguard Worker p2c = ctx.SimpleQueue() 833*da0073e9SAndroid Build Coastguard Worker c2p = ctx.SimpleQueue() 834*da0073e9SAndroid Build Coastguard Worker p = ctx.Process( 835*da0073e9SAndroid Build Coastguard Worker target=TestMultiprocessing._test_event_handle_exporter_consumer, 836*da0073e9SAndroid Build Coastguard Worker args=(e0.ipc_handle(), p2c, c2p), 837*da0073e9SAndroid Build Coastguard Worker ) 838*da0073e9SAndroid Build Coastguard Worker p.start() 839*da0073e9SAndroid Build Coastguard Worker # wait for event in child process is recorded 840*da0073e9SAndroid Build Coastguard Worker c2p.get() 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker self.assertFalse(e0.query()) 843*da0073e9SAndroid Build Coastguard Worker e0.synchronize() 844*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e0.query()) 845*da0073e9SAndroid Build Coastguard Worker p2c.put(0) 846*da0073e9SAndroid Build Coastguard Worker p.join() 847*da0073e9SAndroid Build Coastguard Worker 848*da0073e9SAndroid Build Coastguard Worker def _test_empty_tensor_sharing(self, dtype, device): 849*da0073e9SAndroid Build Coastguard Worker q = mp.Queue() 850*da0073e9SAndroid Build Coastguard Worker empty = torch.tensor([], dtype=dtype, device=device) 851*da0073e9SAndroid Build Coastguard Worker q.put(empty) 852*da0073e9SAndroid Build Coastguard Worker out = q.get(timeout=1) 853*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, empty) 854*da0073e9SAndroid Build Coastguard Worker 855*da0073e9SAndroid Build Coastguard Worker def test_empty_tensor_sharing(self): 856*da0073e9SAndroid Build Coastguard Worker self._test_empty_tensor_sharing(torch.float32, torch.device("cpu")) 857*da0073e9SAndroid Build Coastguard Worker self._test_empty_tensor_sharing(torch.int64, torch.device("cpu")) 858*da0073e9SAndroid Build Coastguard Worker 859*da0073e9SAndroid Build 
Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 860*da0073e9SAndroid Build Coastguard Worker def test_empty_tensor_sharing_cuda(self): 861*da0073e9SAndroid Build Coastguard Worker self._test_empty_tensor_sharing(torch.float32, torch.device("cuda")) 862*da0073e9SAndroid Build Coastguard Worker self._test_empty_tensor_sharing(torch.int64, torch.device("cuda")) 863*da0073e9SAndroid Build Coastguard Worker 864*da0073e9SAndroid Build Coastguard Worker def test_empty_tensor_sharing_meta(self): 865*da0073e9SAndroid Build Coastguard Worker self._test_empty_tensor_sharing(torch.float32, torch.device("meta")) 866*da0073e9SAndroid Build Coastguard Worker self._test_empty_tensor_sharing(torch.int64, torch.device("meta")) 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker def test_tensor_sharing_meta(self): 869*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 870*da0073e9SAndroid Build Coastguard Worker device = torch.device("meta") 871*da0073e9SAndroid Build Coastguard Worker q = mp.Queue() 872*da0073e9SAndroid Build Coastguard Worker empty = torch.tensor([1], dtype=dtype, device=device) 873*da0073e9SAndroid Build Coastguard Worker q.put(empty) 874*da0073e9SAndroid Build Coastguard Worker out = q.get(timeout=1) 875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, empty) 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker def test_meta_simple(self): 878*da0073e9SAndroid Build Coastguard Worker self._test_sharing(mp.get_context("spawn"), "meta", torch.float) 879*da0073e9SAndroid Build Coastguard Worker 880*da0073e9SAndroid Build Coastguard Worker def _test_autograd_sharing(self, var, ctx=mp, is_parameter=False): 881*da0073e9SAndroid Build Coastguard Worker device = "cuda" if var.is_cuda else "cpu" 882*da0073e9SAndroid Build Coastguard Worker 883*da0073e9SAndroid Build Coastguard Worker ready = ctx.Event() 884*da0073e9SAndroid Build 
Coastguard Worker master_modified = ctx.Event() 885*da0073e9SAndroid Build Coastguard Worker queue = ctx.Queue() 886*da0073e9SAndroid Build Coastguard Worker p = ctx.Process( 887*da0073e9SAndroid Build Coastguard Worker target=autograd_sharing, 888*da0073e9SAndroid Build Coastguard Worker args=(queue, ready, master_modified, device, is_parameter), 889*da0073e9SAndroid Build Coastguard Worker ) 890*da0073e9SAndroid Build Coastguard Worker p.daemon = True 891*da0073e9SAndroid Build Coastguard Worker p.start() 892*da0073e9SAndroid Build Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker # This would cause an error if we tried to serialize the hooks, 894*da0073e9SAndroid Build Coastguard Worker # because it's a closure and pickle doesn't support closures. 895*da0073e9SAndroid Build Coastguard Worker @torch.utils.hooks.unserializable_hook 896*da0073e9SAndroid Build Coastguard Worker def hook(*unused): 897*da0073e9SAndroid Build Coastguard Worker pass 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Worker if var.requires_grad: 900*da0073e9SAndroid Build Coastguard Worker var.register_hook(hook) 901*da0073e9SAndroid Build Coastguard Worker var._grad = torch.zeros(5, 5, device=device) 902*da0073e9SAndroid Build Coastguard Worker queue.put(var) 903*da0073e9SAndroid Build Coastguard Worker 904*da0073e9SAndroid Build Coastguard Worker ready.wait() 905*da0073e9SAndroid Build Coastguard Worker var.data[0, 0] = 1000 906*da0073e9SAndroid Build Coastguard Worker var.grad.data[:] = torch.ones(5, 5, device=device) * 4 907*da0073e9SAndroid Build Coastguard Worker master_modified.set() 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker worker_ok = queue.get() 910*da0073e9SAndroid Build Coastguard Worker self.assertTrue(worker_ok) 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker self.assertEqual(var.data, torch.ones(5, 5, device=device)) 913*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(var.grad.data, torch.ones(5, 5, device=device) * 4) 914*da0073e9SAndroid Build Coastguard Worker p.join(100) 915*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 916*da0073e9SAndroid Build Coastguard Worker 917*da0073e9SAndroid Build Coastguard Worker # Check sharing a cudaMalloc allocation with different types of storage. 918*da0073e9SAndroid Build Coastguard Worker # (Issue #11422) 919*da0073e9SAndroid Build Coastguard Worker def _test_mixed_types_cuda_sharing(self, ctx=mp): 920*da0073e9SAndroid Build Coastguard Worker all_ones = torch.ones(2, 2).float() 921*da0073e9SAndroid Build Coastguard Worker all_zeros = torch.zeros(2, 2).byte() 922*da0073e9SAndroid Build Coastguard Worker queue = ctx.Queue() 923*da0073e9SAndroid Build Coastguard Worker event = ctx.Event() 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker p = ctx.Process(target=mixed_type_producer, args=(queue, event)) 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker p.start() 928*da0073e9SAndroid Build Coastguard Worker 929*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 930*da0073e9SAndroid Build Coastguard Worker float_tensor = queue.get() 931*da0073e9SAndroid Build Coastguard Worker byte_tensor = queue.get() 932*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float_tensor, all_ones) 933*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byte_tensor, all_zeros) 934*da0073e9SAndroid Build Coastguard Worker del float_tensor, byte_tensor 935*da0073e9SAndroid Build Coastguard Worker event.set() 936*da0073e9SAndroid Build Coastguard Worker 937*da0073e9SAndroid Build Coastguard Worker time.sleep(5) 938*da0073e9SAndroid Build Coastguard Worker p.join() 939*da0073e9SAndroid Build Coastguard Worker 940*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 941*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ASAN, 942*da0073e9SAndroid 
Build Coastguard Worker "non-deterministically hangs with ASAN https://github.com/pytorch/pytorch/issues/94024", 943*da0073e9SAndroid Build Coastguard Worker ) 944*da0073e9SAndroid Build Coastguard Worker def test_variable_sharing(self): 945*da0073e9SAndroid Build Coastguard Worker for requires_grad in [True, False]: 946*da0073e9SAndroid Build Coastguard Worker var = torch.arange(1.0, 26).view(5, 5).requires_grad_(requires_grad) 947*da0073e9SAndroid Build Coastguard Worker self._test_autograd_sharing(var) 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/14997 950*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_ASAN, "non-deterministically hangs with ASAN") 951*da0073e9SAndroid Build Coastguard Worker def test_leaf_variable_sharing(self): 952*da0073e9SAndroid Build Coastguard Worker devices = ["cpu"] 953*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available() and not NO_MULTIPROCESSING_SPAWN and TEST_CUDA_IPC: 954*da0073e9SAndroid Build Coastguard Worker devices.append("cuda") 955*da0073e9SAndroid Build Coastguard Worker for device in devices: 956*da0073e9SAndroid Build Coastguard Worker for requires_grad in [True, False]: 957*da0073e9SAndroid Build Coastguard Worker var = ( 958*da0073e9SAndroid Build Coastguard Worker torch.arange(1.0, 26, device=device) 959*da0073e9SAndroid Build Coastguard Worker .view(5, 5) 960*da0073e9SAndroid Build Coastguard Worker .requires_grad_(requires_grad) 961*da0073e9SAndroid Build Coastguard Worker ) 962*da0073e9SAndroid Build Coastguard Worker self.assertTrue(var.is_leaf) 963*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") if device == "cuda" else mp 964*da0073e9SAndroid Build Coastguard Worker ready = ctx.Event() 965*da0073e9SAndroid Build Coastguard Worker queue = ctx.Queue() 966*da0073e9SAndroid Build Coastguard Worker p = ctx.Process( 967*da0073e9SAndroid Build Coastguard Worker 
target=requires_grad_variable_sharing, args=(queue, ready) 968*da0073e9SAndroid Build Coastguard Worker ) 969*da0073e9SAndroid Build Coastguard Worker p.daemon = True 970*da0073e9SAndroid Build Coastguard Worker p.start() 971*da0073e9SAndroid Build Coastguard Worker queue.put(var) 972*da0073e9SAndroid Build Coastguard Worker ready.wait() 973*da0073e9SAndroid Build Coastguard Worker worker_requires_grad = queue.get() 974*da0073e9SAndroid Build Coastguard Worker self.assertTrue(worker_requires_grad == requires_grad) 975*da0073e9SAndroid Build Coastguard Worker 976*da0073e9SAndroid Build Coastguard Worker def test_non_leaf_variable_sharing(self): 977*da0073e9SAndroid Build Coastguard Worker devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] 978*da0073e9SAndroid Build Coastguard Worker for device in devices: 979*da0073e9SAndroid Build Coastguard Worker var0 = torch.arange(1.0, 26, device=device).view(5, 5).requires_grad_(True) 980*da0073e9SAndroid Build Coastguard Worker var = var0 * 2 981*da0073e9SAndroid Build Coastguard Worker # Don't use a regular Queue; it uses a background thread (which 982*da0073e9SAndroid Build Coastguard Worker # means we can't catch the exceptions) 983*da0073e9SAndroid Build Coastguard Worker queue = mp.SimpleQueue() 984*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 985*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"requires_grad", lambda: queue.put(var) 986*da0073e9SAndroid Build Coastguard Worker ) 987*da0073e9SAndroid Build Coastguard Worker 988*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 989*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 990*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 991*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 992*da0073e9SAndroid Build Coastguard Worker ) 993*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not 
available") 994*da0073e9SAndroid Build Coastguard Worker def test_cuda_variable_sharing(self): 995*da0073e9SAndroid Build Coastguard Worker for requires_grad in [True, False]: 996*da0073e9SAndroid Build Coastguard Worker var = ( 997*da0073e9SAndroid Build Coastguard Worker torch.arange(1.0, 26, device="cuda") 998*da0073e9SAndroid Build Coastguard Worker .view(5, 5) 999*da0073e9SAndroid Build Coastguard Worker .requires_grad_(requires_grad) 1000*da0073e9SAndroid Build Coastguard Worker ) 1001*da0073e9SAndroid Build Coastguard Worker self._test_autograd_sharing(var, mp.get_context("spawn")) 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1004*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 1005*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 1006*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 1007*da0073e9SAndroid Build Coastguard Worker ) 1008*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1009*da0073e9SAndroid Build Coastguard Worker def test_mixed_types_cuda_sharing(self): 1010*da0073e9SAndroid Build Coastguard Worker self._test_mixed_types_cuda_sharing(mp.get_context("spawn")) 1011*da0073e9SAndroid Build Coastguard Worker 1012*da0073e9SAndroid Build Coastguard Worker def test_parameter_sharing(self): 1013*da0073e9SAndroid Build Coastguard Worker param = Parameter(torch.arange(1.0, 26).view(5, 5)) 1014*da0073e9SAndroid Build Coastguard Worker self._test_autograd_sharing(param, is_parameter=True) 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1017*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 1018*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 1019*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 
1020*da0073e9SAndroid Build Coastguard Worker ) 1021*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1022*da0073e9SAndroid Build Coastguard Worker def test_cuda_parameter_sharing(self): 1023*da0073e9SAndroid Build Coastguard Worker param = Parameter(torch.arange(1.0, 26, device="cuda").view(5, 5)) 1024*da0073e9SAndroid Build Coastguard Worker self._test_autograd_sharing(param, mp.get_context("spawn"), is_parameter=True) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1027*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 1028*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 1029*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 1030*da0073e9SAndroid Build Coastguard Worker ) 1031*da0073e9SAndroid Build Coastguard Worker def test_integer_parameter_serialization_cpu(self): 1032*da0073e9SAndroid Build Coastguard Worker self._test_integer_parameter_serialization(device="cpu") 1033*da0073e9SAndroid Build Coastguard Worker 1034*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1035*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 1036*da0073e9SAndroid Build Coastguard Worker "Disabled for environments that \ 1037*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method", 1038*da0073e9SAndroid Build Coastguard Worker ) 1039*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1040*da0073e9SAndroid Build Coastguard Worker def test_integer_parameter_serialization_cuda(self): 1041*da0073e9SAndroid Build Coastguard Worker self._test_integer_parameter_serialization(device="cuda") 1042*da0073e9SAndroid Build Coastguard Worker 1043*da0073e9SAndroid Build Coastguard Worker def _test_integer_parameter_serialization(self, device): 1044*da0073e9SAndroid Build 
Coastguard Worker param = torch.nn.Parameter( 1045*da0073e9SAndroid Build Coastguard Worker torch.tensor(0, dtype=torch.int64, device=device), requires_grad=False 1046*da0073e9SAndroid Build Coastguard Worker ) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context("spawn") 1049*da0073e9SAndroid Build Coastguard Worker p = ctx.Process(target=integer_parameter_serialization, args=(param,)) 1050*da0073e9SAndroid Build Coastguard Worker p.start() 1051*da0073e9SAndroid Build Coastguard Worker p.join() 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1054*da0073e9SAndroid Build Coastguard Worker 0, 1055*da0073e9SAndroid Build Coastguard Worker p.exitcode, 1056*da0073e9SAndroid Build Coastguard Worker msg=f'Failed to serialize successfully for "{device}" device!', 1057*da0073e9SAndroid Build Coastguard Worker ) 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker def test_empty_shared(self): 1060*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([]) 1061*da0073e9SAndroid Build Coastguard Worker t.share_memory_() 1062*da0073e9SAndroid Build Coastguard Worker 1063*da0073e9SAndroid Build Coastguard Worker def _test_is_shared(self): 1064*da0073e9SAndroid Build Coastguard Worker t = torch.randn(5, 5) 1065*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t.is_shared()) 1066*da0073e9SAndroid Build Coastguard Worker t.share_memory_() 1067*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.is_shared()) 1068*da0073e9SAndroid Build Coastguard Worker 1069*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1070*da0073e9SAndroid Build Coastguard Worker platform == "darwin", "file descriptor strategy is not supported on macOS" 1071*da0073e9SAndroid Build Coastguard Worker ) 1072*da0073e9SAndroid Build Coastguard Worker def test_is_shared(self): 1073*da0073e9SAndroid Build Coastguard Worker 
self._test_is_shared() 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker def test_fs_is_shared(self): 1076*da0073e9SAndroid Build Coastguard Worker with fs_sharing(): 1077*da0073e9SAndroid Build Coastguard Worker self._test_is_shared() 1078*da0073e9SAndroid Build Coastguard Worker 1079*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 1080*da0073e9SAndroid Build Coastguard Worker def test_is_shared_cuda(self): 1081*da0073e9SAndroid Build Coastguard Worker t = torch.randn(5, 5).cuda() 1082*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.is_shared()) 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(sys.platform != "linux", "Only runs on Linux; requires prctl(2)") 1085*da0073e9SAndroid Build Coastguard Worker def test_set_thread_name(self): 1086*da0073e9SAndroid Build Coastguard Worker name = "test name" 1087*da0073e9SAndroid Build Coastguard Worker mp._set_thread_name(name) 1088*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mp._get_thread_name(), name) 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1092*da0073e9SAndroid Build Coastguard Worker run_tests() 1093