xref: /aosp_15_r20/external/pytorch/test/test_multiprocessing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: multiprocessing"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport copy
5*da0073e9SAndroid Build Coastguard Workerimport gc
6*da0073e9SAndroid Build Coastguard Workerimport os
7*da0073e9SAndroid Build Coastguard Workerimport sys
8*da0073e9SAndroid Build Coastguard Workerimport time
9*da0073e9SAndroid Build Coastguard Workerimport unittest
10*da0073e9SAndroid Build Coastguard Workerfrom sys import platform
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerimport torch
13*da0073e9SAndroid Build Coastguard Workerimport torch.cuda
14*da0073e9SAndroid Build Coastguard Workerimport torch.multiprocessing as mp
15*da0073e9SAndroid Build Coastguard Workerimport torch.utils.hooks
16*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import Parameter
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import IS_JETSON
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
19*da0073e9SAndroid Build Coastguard Worker    IS_MACOS,
20*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
21*da0073e9SAndroid Build Coastguard Worker    load_tests,
22*da0073e9SAndroid Build Coastguard Worker    NO_MULTIPROCESSING_SPAWN,
23*da0073e9SAndroid Build Coastguard Worker    run_tests,
24*da0073e9SAndroid Build Coastguard Worker    slowTest,
25*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
26*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM,
27*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO,
28*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
29*da0073e9SAndroid Build Coastguard Worker    TestCase,
30*da0073e9SAndroid Build Coastguard Worker)
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# Number of times leak-prone scenarios are repeated to shake out leaks.
TEST_REPEATS = 30
# Shared-memory file checks only make sense where /dev/shm exists (Linux).
HAS_SHM_FILES = os.path.isdir("/dev/shm")
# Upper bound (seconds) for the polling loops that wait on child processes
# or on /dev/shm cleanup.
MAX_WAITING_TIME_IN_SECONDS = 30

# CUDA IPC is exercised only where it is supported: not on macOS/Windows,
# not on Jetson, and not under ROCm (see linked issue).
TEST_CUDA_IPC = (
    torch.cuda.is_available()
    and sys.platform != "darwin"
    and sys.platform != "win32"
    and not IS_JETSON
    and not TEST_WITH_ROCM
)  # https://github.com/pytorch/pytorch/issues/90940

# Multi-GPU tests additionally need at least two visible devices.
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
class SubProcess(mp.Process):
    """Daemonized worker process that bumps a shared tensor by 3 in place."""

    def __init__(self, tensor):
        super().__init__()
        self.daemon = True
        self.tensor = tensor

    def run(self):
        # In-place add so the mutation is observable through shared memory.
        self.tensor.add_(3)
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker
def _test_cuda_ipc_deadlock_actor(queue, iterations):
    """Consumer half of the IPC-deadlock repro: drain `queue` opportunistically.

    Polls `iterations` times, taking one item per poll when available.
    """
    for _ in range(iterations):
        if not queue.empty():
            queue.get()
        time.sleep(0.01)
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
def _test_cuda_ipc_deadlock_learner(queue, iterations):
    """Producer half of the IPC-deadlock repro.

    Repeatedly publishes a deep copy of a CUDA LSTM state_dict, skipping
    a put whenever the queue is full so it never blocks.
    """
    net = torch.nn.LSTM(1, 1).cuda()
    for _ in range(iterations):
        if not queue.full():
            queue.put(copy.deepcopy(net.state_dict()))
        time.sleep(0.01)
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker
def simple_fill(queue, event):
    """Child-process body: overwrite the first shared tensor with 4s, then signal."""
    shared = queue.get()
    shared[0][:] = 4
    event.set()
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker
def simple_pool_fill(tensor):
    """Fill the (shared) input tensor with 4s and return a fresh tensor of 5s."""
    tensor.fill_(4)
    return tensor + 1
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker
def send_tensor(queue, event, device, dtype):
    """Share one tensor of ones twice through `queue`, then wait for the consumer."""
    shared = torch.ones(5, 5, dtype=dtype, device=device)
    for _ in range(2):
        queue.put(shared)
    event.wait()
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker
def send_and_delete_tensors(queue, event, device, dtype, count, size=5):
    """Produce `count` tensors (each filled with its index), dropping the local
    reference immediately after each put, then wait for the consumer."""
    for idx in range(count):
        queue.put(torch.full([size], idx, dtype=dtype, device=device))
    event.wait()
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker
def receive_and_send_sum(queue, out_queue, event, device, dtype, count, size=5):
    """Accumulate `count` tensors from `queue`, emit their sum on `out_queue`,
    then wait for the consumer to acknowledge."""
    total = torch.full([size], 0, dtype=dtype, device=device)
    for _ in range(count):
        total += queue.get()
    out_queue.put(total)
    event.wait()
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
def receive_and_send(queue, out_queue, event, count):
    """Relay `count` tensors from `queue` to `out_queue` as fresh clones,
    then wait for the consumer."""
    for _ in range(count):
        out_queue.put(queue.get().clone())
    event.wait()
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker
def sum_tensors(inq, outq):
    """On GPU 1, report (sum, device index, numel, storage size) for every
    tensor received in a single batch on `inq`."""
    with torch.cuda.device(1):
        for tensor in inq.get():
            summary = (
                tensor.sum().item(),
                tensor.get_device(),
                tensor.numel(),
                tensor.storage().size(),
            )
            outq.put(summary)
131*da0073e9SAndroid Build Coastguard Worker
132*da0073e9SAndroid Build Coastguard Worker
def queue_get_exception(inqueue, outqueue):
    """Attempt a CUDA allocation and report the outcome on `outqueue`.

    Puts the raised exception object on failure, or the string
    "no exception" on success.
    """
    os.close(2)  # hide expected error message
    try:
        torch.zeros(5, 5).cuda()
    except Exception as exc:
        result = exc
    else:
        result = "no exception"
    outqueue.put(result)
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker
# Multiply by two in a separate stream
def cuda_multiply_two(queue, ready, done):
    """Consumer half of a cross-process CUDA-event handshake.

    Waits on the received event, doubles the shared tensor on a side
    stream, re-records the event, and signals completion.
    """
    ready.set()
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        cuda_event, tensor = queue.get()
        cuda_event.wait()
        tensor.mul_(2)
        cuda_event.record()
        done.set()
        del cuda_event
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker
def requires_grad_variable_sharing(queue, ready):
    """Echo back whether the shared tensor arrived with requires_grad set."""
    shared = queue.get()
    ready.set()
    queue.put(shared.requires_grad)
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker
def integer_parameter_serialization(iparam):
    """Touch the deserialized integer Parameter with arithmetic; this raises
    if serialization mangled it."""
    _ = iparam + 1
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker
def autograd_sharing(queue, ready, master_modified, device, is_parameter):
    """Child-process body verifying a shared autograd tensor's state.

    After the parent signals that it has mutated the shared tensor, check
    its data, grad, hooks, and concrete type, write back ones (and a grad)
    so the parent can observe the reverse direction, and report a single
    pass/fail bool on `queue`.
    """
    shared = queue.get()
    ready.set()
    master_modified.wait()

    reference = torch.arange(1.0, 26, device=device).view(5, 5)
    reference[0, 0] = 1000
    checks_pass = shared.data.equal(reference)
    shared.data[:] = torch.ones(5, 5, device=device)

    checks_pass &= shared.grad is None
    checks_pass &= not shared._backward_hooks
    expected_type = Parameter if is_parameter else torch.Tensor
    checks_pass &= type(shared) == expected_type
    shared._grad = torch.ones(5, 5, device=device)

    queue.put(checks_pass)
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker
def mixed_type_producer(queue, event):
    """Alternate float and byte CUDA tensors onto `queue`, pacing each pair
    on `event` so the consumer acknowledges before the next round."""
    for _ in range(10):
        queue.put(torch.ones(2, 2).float().cuda())
        queue.put(torch.zeros(2, 2).byte().cuda())
        event.wait()
        event.clear()
195*da0073e9SAndroid Build Coastguard Worker
196*da0073e9SAndroid Build Coastguard Worker
def simple_autograd_function(a=1):
    """Run a tiny backward pass (exercises autograd in a child process) and
    return `a` squared."""
    leaf = torch.rand(3).requires_grad_(True)
    leaf.mean().backward()
    return a * a
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def fs_sharing():
    """Temporarily force torch.multiprocessing's file_system sharing strategy,
    restoring the previous strategy on exit (even on error)."""
    saved = mp.get_sharing_strategy()
    mp.set_sharing_strategy("file_system")
    try:
        yield
    finally:
        mp.set_sharing_strategy(saved)
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker
class leak_checker:
    """Context manager that watches a test for leaked /dev/shm files.

    Tracks the pids of the test process and any registered children, and on
    clean exit asserts that no `torch_<pid>*` shared-memory files remain.
    """

    def __init__(self, test_case):
        self.checked_pids = [os.getpid()]
        self.test_case = test_case

    def __enter__(self):
        self.next_fds = self._get_next_fds(10)
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
        # Only run leak checks if the body did not raise.
        if args[0] is None:
            # Check that the 10th available file-descriptor at the end of the
            # test is no more than 4 higher than the 10th available at the
            # start. This attempts to catch file descriptor leaks, but allows
            # one-off initialization that may use up a file descriptor
            # TODO: Disabled because this check is too flaky
            # available_fds = self._get_next_fds(10)
            # self.test_case.assertLessEqual(
            #     available_fds[-1] - self.next_fds[-1], 5)
            self.test_case.assertFalse(self.has_shm_files())
        return False

    def check_pid(self, pid):
        """Register a child pid whose shm files should be tracked."""
        self.checked_pids.append(pid)

    def _get_next_fds(self, n=1):
        # os.dup always picks the lowest-numbered unused descriptor, so
        # duplicating stdin n times reveals the n lowest available fds.
        duplicated = [os.dup(0) for _ in range(n)]
        for descriptor in duplicated:
            os.close(descriptor)
        return duplicated

    def has_shm_files(self, wait=True):
        """Return True if tracked shm files exist, optionally polling for up
        to MAX_WAITING_TIME_IN_SECONDS to let asynchronous cleanup finish."""
        if not HAS_SHM_FILES:
            return False

        found = self._has_shm_files()
        if not found or not wait or mp.get_sharing_strategy() != "file_system":
            return found

        waited = 0.0
        step = 0.5
        while found and waited <= MAX_WAITING_TIME_IN_SECONDS:
            time.sleep(step)
            waited += step
            found = self._has_shm_files()

        return found

    def _has_shm_files(self):
        gc.collect()  # drop dangling tensor references before scanning
        prefixes = ["torch_" + str(pid) for pid in self.checked_pids]
        return any(
            entry.startswith(prefix)
            for entry in os.listdir("/dev/shm")
            for prefix in prefixes
        )
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
275*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
276*da0073e9SAndroid Build Coastguard Worker    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
277*da0073e9SAndroid Build Coastguard Worker)
278*da0073e9SAndroid Build Coastguard Workerclass TestMultiprocessing(TestCase):
279*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
280*da0073e9SAndroid Build Coastguard Worker        # This will keep tests isolated from each-other
281*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
282*da0073e9SAndroid Build Coastguard Worker            torch.cuda.ipc_collect()
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker    def _test_sharing(self, ctx=mp, device="cpu", dtype=torch.float, repeat=1):
285*da0073e9SAndroid Build Coastguard Worker        def test_fill():
286*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros(5, 5).to(device, dtype)
287*da0073e9SAndroid Build Coastguard Worker            q = ctx.Queue()
288*da0073e9SAndroid Build Coastguard Worker            e = ctx.Event()
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker            data = [x, x[:, 1]]
291*da0073e9SAndroid Build Coastguard Worker            q.put(data)
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker            p = ctx.Process(target=simple_fill, args=(q, e))
294*da0073e9SAndroid Build Coastguard Worker            p.daemon = True
295*da0073e9SAndroid Build Coastguard Worker            lc.check_pid(p.pid)
296*da0073e9SAndroid Build Coastguard Worker            p.start()
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker            total_waiting_time = 0
299*da0073e9SAndroid Build Coastguard Worker            waiting_time = 0.5
300*da0073e9SAndroid Build Coastguard Worker            is_set = False
301*da0073e9SAndroid Build Coastguard Worker            # Once the child process is done, it will set the event to notify the
302*da0073e9SAndroid Build Coastguard Worker            # parent accordingly
303*da0073e9SAndroid Build Coastguard Worker            while total_waiting_time <= MAX_WAITING_TIME_IN_SECONDS and not is_set:
304*da0073e9SAndroid Build Coastguard Worker                time.sleep(waiting_time)
305*da0073e9SAndroid Build Coastguard Worker                total_waiting_time += waiting_time
306*da0073e9SAndroid Build Coastguard Worker                is_set = e.is_set()
307*da0073e9SAndroid Build Coastguard Worker
308*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(is_set)
309*da0073e9SAndroid Build Coastguard Worker            if device != "meta":
310*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(data[0].eq(4).all())
311*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(data[1].eq(4).all())
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker            p.join(100)
314*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.is_alive())
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        def test_receive():
317*da0073e9SAndroid Build Coastguard Worker            q = ctx.Queue()
318*da0073e9SAndroid Build Coastguard Worker            e = ctx.Event()
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker            p = ctx.Process(target=send_tensor, args=(q, e, device, dtype))
321*da0073e9SAndroid Build Coastguard Worker            p.daemon = True
322*da0073e9SAndroid Build Coastguard Worker            lc.check_pid(p.pid)
323*da0073e9SAndroid Build Coastguard Worker            p.start()
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker            t1 = q.get()
326*da0073e9SAndroid Build Coastguard Worker            t2 = q.get()
327*da0073e9SAndroid Build Coastguard Worker            if device == "meta":
328*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t1.size(), t2.size())
329*da0073e9SAndroid Build Coastguard Worker            else:
330*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(t1.eq(1).all())
331*da0073e9SAndroid Build Coastguard Worker            s1 = t1.storage()
332*da0073e9SAndroid Build Coastguard Worker            s2 = t2.storage()
333*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(type(s1), type(s2))
334*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s1.data_ptr(), s1.data_ptr())
335*da0073e9SAndroid Build Coastguard Worker            if device == "meta":
336*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(s1.size(), s2.size())
337*da0073e9SAndroid Build Coastguard Worker            else:
338*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(s1, s2)
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker            # We need to delete this tensors to allow producer (child process)
341*da0073e9SAndroid Build Coastguard Worker            # collect them properly
342*da0073e9SAndroid Build Coastguard Worker            del t1, t2
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker            # Mark the event as done and join the process
345*da0073e9SAndroid Build Coastguard Worker            e.set()
346*da0073e9SAndroid Build Coastguard Worker            p.join(100)
347*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.is_alive())
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker        with leak_checker(self) as lc:
350*da0073e9SAndroid Build Coastguard Worker            for _ in range(repeat):
351*da0073e9SAndroid Build Coastguard Worker                test_fill()
352*da0073e9SAndroid Build Coastguard Worker                test_receive()
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker    def _test_preserve_sharing(self, ctx=mp, repeat=1):
355*da0073e9SAndroid Build Coastguard Worker        def do_test():
356*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 5)
357*da0073e9SAndroid Build Coastguard Worker            data = [x.storage(), x, x[2], x[:, 1]]
358*da0073e9SAndroid Build Coastguard Worker            q = ctx.Queue()
359*da0073e9SAndroid Build Coastguard Worker            q.put(data)
360*da0073e9SAndroid Build Coastguard Worker            new_data = q.get(timeout=1)
361*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(new_data, data, atol=0, rtol=0)
362*da0073e9SAndroid Build Coastguard Worker            storage_cdata = data[0]._cdata
363*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(new_data[0]._cdata, storage_cdata)
364*da0073e9SAndroid Build Coastguard Worker            for t in new_data[1:]:
365*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t.storage()._cdata, storage_cdata)
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker        with leak_checker(self):
368*da0073e9SAndroid Build Coastguard Worker            for _ in range(repeat):
369*da0073e9SAndroid Build Coastguard Worker                do_test()
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker    def _test_pool(self, ctx=mp, repeat=1):
372*da0073e9SAndroid Build Coastguard Worker        def do_test():
373*da0073e9SAndroid Build Coastguard Worker            p = ctx.Pool(2)
374*da0073e9SAndroid Build Coastguard Worker            for proc in p._pool:
375*da0073e9SAndroid Build Coastguard Worker                lc.check_pid(proc.pid)
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker            buffers = [torch.zeros(2, 2) for i in range(4)]
378*da0073e9SAndroid Build Coastguard Worker            results = p.map(simple_pool_fill, buffers, 1)
379*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(results), len(buffers))
380*da0073e9SAndroid Build Coastguard Worker            for r in results:
381*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r, torch.ones(2, 2) * 5, atol=0, rtol=0)
382*da0073e9SAndroid Build Coastguard Worker            for b in buffers:
383*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b, torch.ones(2, 2) * 4, atol=0, rtol=0)
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker            p.close()
386*da0073e9SAndroid Build Coastguard Worker            p.join()
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker        with leak_checker(self) as lc:
389*da0073e9SAndroid Build Coastguard Worker            for _ in range(repeat):
390*da0073e9SAndroid Build Coastguard Worker                do_test()
391*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        platform == "darwin", "file descriptor strategy is not supported on macOS"
    )
    @unittest.skipIf(
        TEST_WITH_ASAN,
        "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326",
    )
    def test_fd_sharing(self):
        # Tensor sharing via the default (file-descriptor) strategy, repeated
        # to surface descriptor leaks.
        self._test_sharing(repeat=TEST_REPEATS)
401*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        platform == "darwin", "file descriptor strategy is not supported on macOS"
    )
    def test_fd_preserve_sharing(self):
        # Storage identity must survive a Queue round-trip under fd sharing.
        self._test_preserve_sharing(repeat=TEST_REPEATS)
407*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        platform == "darwin", "file descriptor strategy is not supported on macOS"
    )
    def test_fd_pool(self):
        # Pool.map over shared buffers under the file-descriptor strategy.
        self._test_pool(repeat=TEST_REPEATS)
413*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        TEST_WITH_ASAN,
        "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326",
    )
    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
    )
    def test_fs_sharing(self):
        # Same sharing scenarios as test_fd_sharing, forced through the
        # file_system strategy.
        with fs_sharing():
            # The test works but is very slow on MacOS, see https://github.com/pytorch/pytorch/pull/93183,
            # so run it only once there. The delay is in waiting for the child process to terminate (join)
            repeat = 1 if IS_MACOS else TEST_REPEATS
            self._test_sharing(repeat=repeat)
428*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
    )
    def test_fs_preserve_sharing(self):
        # Storage identity round-trip, under the file_system strategy.
        with fs_sharing():
            self._test_preserve_sharing(repeat=TEST_REPEATS)
436*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
    )
    def test_fs_pool(self):
        # Pool.map over shared buffers, under the file_system strategy.
        with fs_sharing():
            self._test_pool(repeat=TEST_REPEATS)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not HAS_SHM_FILES, "don't not how to check if shm files exist")
446*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
447*da0073e9SAndroid Build Coastguard Worker        TEST_WITH_TORCHDYNAMO,
448*da0073e9SAndroid Build Coastguard Worker        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
449*da0073e9SAndroid Build Coastguard Worker    )
450*da0073e9SAndroid Build Coastguard Worker    def test_fs(self):
451*da0073e9SAndroid Build Coastguard Worker        def queue_put():
452*da0073e9SAndroid Build Coastguard Worker            x = torch.DoubleStorage(4)
453*da0073e9SAndroid Build Coastguard Worker            q = mp.Queue()
454*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(lc.has_shm_files())
455*da0073e9SAndroid Build Coastguard Worker            q.put(x)
456*da0073e9SAndroid Build Coastguard Worker            time.sleep(0.05)  # queue serializes asynchronously
457*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(lc.has_shm_files(wait=False))
458*da0073e9SAndroid Build Coastguard Worker            q.get()
459*da0073e9SAndroid Build Coastguard Worker
460*da0073e9SAndroid Build Coastguard Worker        with fs_sharing(), leak_checker(self) as lc:
461*da0073e9SAndroid Build Coastguard Worker            for _ in range(TEST_REPEATS):
462*da0073e9SAndroid Build Coastguard Worker                queue_put()
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Worker    def test_inherit_tensor(self):
465*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros(5, 5)
466*da0073e9SAndroid Build Coastguard Worker        p = SubProcess(t.share_memory_())
467*da0073e9SAndroid Build Coastguard Worker        p.start()
468*da0073e9SAndroid Build Coastguard Worker        p.join(2)
469*da0073e9SAndroid Build Coastguard Worker        if p.exitcode is None:
470*da0073e9SAndroid Build Coastguard Worker            print("test_inherit_tensor: SubProcess too slow")
471*da0073e9SAndroid Build Coastguard Worker        else:
472*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, torch.ones(5, 5) * 3, atol=0, rtol=0)
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Test needs to use fork multiprocessing")
475*da0073e9SAndroid Build Coastguard Worker    def test_autograd_errors(self):
476*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("fork")
477*da0073e9SAndroid Build Coastguard Worker        simple_autograd_function()
478*da0073e9SAndroid Build Coastguard Worker        # Autograd only uses thread when GPUs are involved
479*da0073e9SAndroid Build Coastguard Worker        if (
480*da0073e9SAndroid Build Coastguard Worker            torch.cuda.is_available()
481*da0073e9SAndroid Build Coastguard Worker            or torch.backends.mps.is_available()
482*da0073e9SAndroid Build Coastguard Worker            or torch.xpu.is_available()
483*da0073e9SAndroid Build Coastguard Worker        ):
484*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Unable to handle autograd"):
485*da0073e9SAndroid Build Coastguard Worker                with ctx.Pool(3) as pool:
486*da0073e9SAndroid Build Coastguard Worker                    pool.map(simple_autograd_function, [1, 2, 3])
487*da0073e9SAndroid Build Coastguard Worker        else:
488*da0073e9SAndroid Build Coastguard Worker            with ctx.Pool(3) as pool:
489*da0073e9SAndroid Build Coastguard Worker                pool.map(simple_autograd_function, [1, 2, 3])
490*da0073e9SAndroid Build Coastguard Worker
491*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
492*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN, "Test needs to use spawn multiprocessing"
493*da0073e9SAndroid Build Coastguard Worker    )
494*da0073e9SAndroid Build Coastguard Worker    def test_autograd_fine_with_spawn(self):
495*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
496*da0073e9SAndroid Build Coastguard Worker        simple_autograd_function()
497*da0073e9SAndroid Build Coastguard Worker        with ctx.Pool(3) as pool:
498*da0073e9SAndroid Build Coastguard Worker            pool.map(simple_autograd_function, [1, 2, 3])
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
501*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
502*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
503*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
504*da0073e9SAndroid Build Coastguard Worker    )
505*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
506*da0073e9SAndroid Build Coastguard Worker    def test_cuda_simple(self):
507*da0073e9SAndroid Build Coastguard Worker        torch.cuda.FloatTensor([1])  # initialize CUDA outside of leak checker
508*da0073e9SAndroid Build Coastguard Worker        self._test_sharing(mp.get_context("spawn"), "cuda", torch.float)
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
511*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
512*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
513*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
514*da0073e9SAndroid Build Coastguard Worker    )
515*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
516*da0073e9SAndroid Build Coastguard Worker    def test_cuda_memory_allocation(self):
517*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
518*da0073e9SAndroid Build Coastguard Worker        q = ctx.Queue()
519*da0073e9SAndroid Build Coastguard Worker        e = ctx.Event()
520*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(
521*da0073e9SAndroid Build Coastguard Worker            target=send_and_delete_tensors, args=(q, e, "cuda", torch.int, 5)
522*da0073e9SAndroid Build Coastguard Worker        )
523*da0073e9SAndroid Build Coastguard Worker        p.start()
524*da0073e9SAndroid Build Coastguard Worker        t = []
525*da0073e9SAndroid Build Coastguard Worker        for _ in range(5):
526*da0073e9SAndroid Build Coastguard Worker            t.append(q.get())
527*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0], torch.full([5], 0, dtype=torch.int32))
528*da0073e9SAndroid Build Coastguard Worker        del t
529*da0073e9SAndroid Build Coastguard Worker        e.set()
530*da0073e9SAndroid Build Coastguard Worker        p.join(1)
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
533*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
534*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
535*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
536*da0073e9SAndroid Build Coastguard Worker    )
537*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
538*da0073e9SAndroid Build Coastguard Worker    def test_cuda_ipc_deadlock(self):
539*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
540*da0073e9SAndroid Build Coastguard Worker        queue = ctx.Queue(1)
541*da0073e9SAndroid Build Coastguard Worker        processes = dict(
542*da0073e9SAndroid Build Coastguard Worker            a=ctx.Process(target=_test_cuda_ipc_deadlock_actor, args=(queue, 100)),
543*da0073e9SAndroid Build Coastguard Worker            l=ctx.Process(target=_test_cuda_ipc_deadlock_learner, args=(queue, 100)),
544*da0073e9SAndroid Build Coastguard Worker        )
545*da0073e9SAndroid Build Coastguard Worker
546*da0073e9SAndroid Build Coastguard Worker        for p in processes.values():
547*da0073e9SAndroid Build Coastguard Worker            p.start()
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker        for p in processes.values():
550*da0073e9SAndroid Build Coastguard Worker            p.join(10)
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker        for p in processes.values():
553*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.is_alive())
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker    @slowTest
556*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
557*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
558*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
559*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
560*da0073e9SAndroid Build Coastguard Worker    )
561*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
562*da0073e9SAndroid Build Coastguard Worker    def test_cuda_send_many(self, name=None, size=5, count=100000):
563*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
564*da0073e9SAndroid Build Coastguard Worker        q1 = ctx.Queue()
565*da0073e9SAndroid Build Coastguard Worker        q2 = ctx.Queue()
566*da0073e9SAndroid Build Coastguard Worker        q3 = ctx.Queue()
567*da0073e9SAndroid Build Coastguard Worker        e1 = ctx.Event()
568*da0073e9SAndroid Build Coastguard Worker        e2 = ctx.Event()
569*da0073e9SAndroid Build Coastguard Worker        e3 = ctx.Event()
570*da0073e9SAndroid Build Coastguard Worker        p1 = ctx.Process(
571*da0073e9SAndroid Build Coastguard Worker            target=send_and_delete_tensors,
572*da0073e9SAndroid Build Coastguard Worker            args=(q1, e1, "cuda", torch.long, count, size),
573*da0073e9SAndroid Build Coastguard Worker        )
574*da0073e9SAndroid Build Coastguard Worker        p2 = ctx.Process(target=receive_and_send, args=(q1, q2, e2, count))
575*da0073e9SAndroid Build Coastguard Worker        p3 = ctx.Process(
576*da0073e9SAndroid Build Coastguard Worker            target=receive_and_send_sum,
577*da0073e9SAndroid Build Coastguard Worker            args=(q2, q3, e3, "cuda", torch.long, count, size),
578*da0073e9SAndroid Build Coastguard Worker        )
579*da0073e9SAndroid Build Coastguard Worker        p1.start()
580*da0073e9SAndroid Build Coastguard Worker        p2.start()
581*da0073e9SAndroid Build Coastguard Worker        p3.start()
582*da0073e9SAndroid Build Coastguard Worker        result = q3.get()
583*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result[0], int(count * (count - 1) / 2))
584*da0073e9SAndroid Build Coastguard Worker        del result
585*da0073e9SAndroid Build Coastguard Worker        e1.set()
586*da0073e9SAndroid Build Coastguard Worker        e2.set()
587*da0073e9SAndroid Build Coastguard Worker        e3.set()
588*da0073e9SAndroid Build Coastguard Worker        p1.join(1)
589*da0073e9SAndroid Build Coastguard Worker        p2.join(1)
590*da0073e9SAndroid Build Coastguard Worker        p3.join(1)
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
593*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
594*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
595*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
596*da0073e9SAndroid Build Coastguard Worker    )
597*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
598*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU")
599*da0073e9SAndroid Build Coastguard Worker    def test_cuda_small_tensors(self):
600*da0073e9SAndroid Build Coastguard Worker        # Check multiple small tensors which will likely use the same
601*da0073e9SAndroid Build Coastguard Worker        # underlying cached allocation
602*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
603*da0073e9SAndroid Build Coastguard Worker        tensors = []
604*da0073e9SAndroid Build Coastguard Worker        for i in range(5):
605*da0073e9SAndroid Build Coastguard Worker            device = i % 2
606*da0073e9SAndroid Build Coastguard Worker            tensors += [torch.arange(i * 5.0, (i + 1) * 5).cuda(device)]
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker        inq = ctx.Queue()
609*da0073e9SAndroid Build Coastguard Worker        outq = ctx.Queue()
610*da0073e9SAndroid Build Coastguard Worker        inq.put(tensors)
611*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(target=sum_tensors, args=(inq, outq))
612*da0073e9SAndroid Build Coastguard Worker        p.start()
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker        results = []
615*da0073e9SAndroid Build Coastguard Worker        for _ in range(5):
616*da0073e9SAndroid Build Coastguard Worker            results.append(outq.get())
617*da0073e9SAndroid Build Coastguard Worker        p.join()
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        for i, _tensor in enumerate(tensors):
620*da0073e9SAndroid Build Coastguard Worker            v, device, tensor_size, storage_size = results[i]
621*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v, torch.arange(i * 5.0, (i + 1) * 5).sum())
622*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device, i % 2)
623*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_size, 5)
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker            # You might think this should be the case, but it's not!  After
626*da0073e9SAndroid Build Coastguard Worker            # data from the CUDA caching allocator goes through IPC, the
627*da0073e9SAndroid Build Coastguard Worker            # size of the storage is the size of the *cached cudaMalloc for
628*da0073e9SAndroid Build Coastguard Worker            # the entire memory block* of the storage, not just the storage.
629*da0073e9SAndroid Build Coastguard Worker            # See Note [CUDA IPC and the caching allocator] for more info
630*da0073e9SAndroid Build Coastguard Worker            #
631*da0073e9SAndroid Build Coastguard Worker            # self.assertEqual(storage_size, 5)
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Worker        # Collect current process (producer) files, make sure nothing holds
634*da0073e9SAndroid Build Coastguard Worker        # ref to the sent tensors
635*da0073e9SAndroid Build Coastguard Worker        del _tensor
636*da0073e9SAndroid Build Coastguard Worker        del tensors
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker        # We need to collect, as CUDA MP implementation holds one shared
639*da0073e9SAndroid Build Coastguard Worker        # memory 'file' for performance reason
640*da0073e9SAndroid Build Coastguard Worker        torch.cuda.ipc_collect()
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
643*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
644*da0073e9SAndroid Build Coastguard Worker    def test_cuda_bad_call(self):
645*da0073e9SAndroid Build Coastguard Worker        # Initialize CUDA
646*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros(5, 5).cuda().cpu()
647*da0073e9SAndroid Build Coastguard Worker        inq = mp.Queue()
648*da0073e9SAndroid Build Coastguard Worker        outq = mp.Queue()
649*da0073e9SAndroid Build Coastguard Worker        p = mp.Process(target=queue_get_exception, args=(inq, outq))
650*da0073e9SAndroid Build Coastguard Worker        p.start()
651*da0073e9SAndroid Build Coastguard Worker        inq.put(t)
652*da0073e9SAndroid Build Coastguard Worker        p.join()
653*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(outq.get(), RuntimeError)
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
656*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
657*da0073e9SAndroid Build Coastguard Worker    def test_wrong_cuda_fork(self):
658*da0073e9SAndroid Build Coastguard Worker        stderr = TestCase.runWithPytorchAPIUsageStderr(
659*da0073e9SAndroid Build Coastguard Worker            """\
660*da0073e9SAndroid Build Coastguard Workerimport torch
661*da0073e9SAndroid Build Coastguard Workerfrom torch.multiprocessing import Process
662*da0073e9SAndroid Build Coastguard Workerdef run(rank):
663*da0073e9SAndroid Build Coastguard Worker    torch.cuda.set_device(rank)
664*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
665*da0073e9SAndroid Build Coastguard Worker    size = 2
666*da0073e9SAndroid Build Coastguard Worker    processes = []
667*da0073e9SAndroid Build Coastguard Worker    for rank in range(size):
668*da0073e9SAndroid Build Coastguard Worker        # it would work fine without the line below
669*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(20, 2).cuda()
670*da0073e9SAndroid Build Coastguard Worker        p = Process(target=run, args=(rank,))
671*da0073e9SAndroid Build Coastguard Worker        p.start()
672*da0073e9SAndroid Build Coastguard Worker        processes.append(p)
673*da0073e9SAndroid Build Coastguard Worker    for p in processes:
674*da0073e9SAndroid Build Coastguard Worker        p.join()
675*da0073e9SAndroid Build Coastguard Worker"""
676*da0073e9SAndroid Build Coastguard Worker        )
677*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(stderr, "Cannot re-initialize CUDA in forked subprocess.")
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
680*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
681*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
682*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
683*da0073e9SAndroid Build Coastguard Worker    )
684*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
685*da0073e9SAndroid Build Coastguard Worker    def test_event(self):
686*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
687*da0073e9SAndroid Build Coastguard Worker        queue = ctx.Queue()
688*da0073e9SAndroid Build Coastguard Worker        ready = ctx.Event()
689*da0073e9SAndroid Build Coastguard Worker        done = ctx.Event()
690*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(target=cuda_multiply_two, args=(queue, ready, done))
691*da0073e9SAndroid Build Coastguard Worker        p.start()
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        ready.wait()
694*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(torch.cuda.Stream()):
695*da0073e9SAndroid Build Coastguard Worker            tensor = torch.cuda.FloatTensor([1, 1, 1, 1])
696*da0073e9SAndroid Build Coastguard Worker            # Use a sleep kernel to test events. Without the event, the
697*da0073e9SAndroid Build Coastguard Worker            # multiply happens before the add.
698*da0073e9SAndroid Build Coastguard Worker            event = torch.cuda.Event(interprocess=True)
699*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(20000000)  # about 30 ms
700*da0073e9SAndroid Build Coastguard Worker            tensor.add_(1)
701*da0073e9SAndroid Build Coastguard Worker            event.record()
702*da0073e9SAndroid Build Coastguard Worker            queue.put((event, tensor))
703*da0073e9SAndroid Build Coastguard Worker            done.wait()  # must wait until subprocess records event
704*da0073e9SAndroid Build Coastguard Worker            event.synchronize()
705*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(tensor), [4, 4, 4, 4])
706*da0073e9SAndroid Build Coastguard Worker        p.join()
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker    @staticmethod
709*da0073e9SAndroid Build Coastguard Worker    def _test_event_multiprocess_child(event, p2c, c2p):
710*da0073e9SAndroid Build Coastguard Worker        c2p.put(0)  # notify parent child is ready
711*da0073e9SAndroid Build Coastguard Worker        p2c.get()  # wait for record in parent
712*da0073e9SAndroid Build Coastguard Worker        event.synchronize()
713*da0073e9SAndroid Build Coastguard Worker        c2p.put(1)  # notify parent synchronization is done
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
716*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
717*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
718*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
719*da0073e9SAndroid Build Coastguard Worker    )
720*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
721*da0073e9SAndroid Build Coastguard Worker    def test_event_multiprocess(self):
722*da0073e9SAndroid Build Coastguard Worker        event = torch.cuda.Event(enable_timing=False, interprocess=True)
723*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(event.query())
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
726*da0073e9SAndroid Build Coastguard Worker        p2c = ctx.SimpleQueue()
727*da0073e9SAndroid Build Coastguard Worker        c2p = ctx.SimpleQueue()
728*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(
729*da0073e9SAndroid Build Coastguard Worker            target=TestMultiprocessing._test_event_multiprocess_child,
730*da0073e9SAndroid Build Coastguard Worker            args=(event, p2c, c2p),
731*da0073e9SAndroid Build Coastguard Worker        )
732*da0073e9SAndroid Build Coastguard Worker        p.start()
733*da0073e9SAndroid Build Coastguard Worker
734*da0073e9SAndroid Build Coastguard Worker        c2p.get()  # wait for until child process is ready
735*da0073e9SAndroid Build Coastguard Worker        torch.cuda._sleep(50000000)  # spin for about 50 ms
736*da0073e9SAndroid Build Coastguard Worker        event.record()
737*da0073e9SAndroid Build Coastguard Worker        p2c.put(0)  # notify child event is recorded
738*da0073e9SAndroid Build Coastguard Worker
739*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(event.query())
740*da0073e9SAndroid Build Coastguard Worker        c2p.get()  # wait for synchronization in child
741*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(event.query())
742*da0073e9SAndroid Build Coastguard Worker        p.join()
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
745*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
746*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
747*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
748*da0073e9SAndroid Build Coastguard Worker    )
749*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
750*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU")
751*da0073e9SAndroid Build Coastguard Worker    def test_event_handle_multi_gpu(self):
752*da0073e9SAndroid Build Coastguard Worker        d0 = torch.device("cuda:0")
753*da0073e9SAndroid Build Coastguard Worker        d1 = torch.device("cuda:1")
754*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
755*da0073e9SAndroid Build Coastguard Worker            e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
758*da0073e9SAndroid Build Coastguard Worker            # create handle on different device from un-recorded event
759*da0073e9SAndroid Build Coastguard Worker            e0.ipc_handle()
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d0):
762*da0073e9SAndroid Build Coastguard Worker            e1 = torch.cuda.Event(enable_timing=False, interprocess=True)
763*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.Stream()
764*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(50000000)  # spin for about 50 ms
765*da0073e9SAndroid Build Coastguard Worker            e1.record(stream)
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(d1):
768*da0073e9SAndroid Build Coastguard Worker            # create handle on different device from recorded event
769*da0073e9SAndroid Build Coastguard Worker            e1.ipc_handle()
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker    @staticmethod
772*da0073e9SAndroid Build Coastguard Worker    def _test_event_handle_importer_consumer(handle, p2c, c2p):
773*da0073e9SAndroid Build Coastguard Worker        e1 = torch.cuda.Event.from_ipc_handle(0, handle)
774*da0073e9SAndroid Build Coastguard Worker        c2p.put(0)  # notify parent child is ready
775*da0073e9SAndroid Build Coastguard Worker        p2c.get()  # wait for record in parent
776*da0073e9SAndroid Build Coastguard Worker        e1.synchronize()
777*da0073e9SAndroid Build Coastguard Worker        c2p.put(1)  # notify synchronization is done in child
778*da0073e9SAndroid Build Coastguard Worker        p2c.get()  # wait for parent to finish before destructing child event
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
781*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
782*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
783*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
784*da0073e9SAndroid Build Coastguard Worker    )
785*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
786*da0073e9SAndroid Build Coastguard Worker    def test_event_handle_importer(self):
787*da0073e9SAndroid Build Coastguard Worker        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
788*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e0.query())
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
791*da0073e9SAndroid Build Coastguard Worker        p2c = ctx.SimpleQueue()
792*da0073e9SAndroid Build Coastguard Worker        c2p = ctx.SimpleQueue()
793*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(
794*da0073e9SAndroid Build Coastguard Worker            target=TestMultiprocessing._test_event_handle_importer_consumer,
795*da0073e9SAndroid Build Coastguard Worker            args=(e0.ipc_handle(), p2c, c2p),
796*da0073e9SAndroid Build Coastguard Worker        )
797*da0073e9SAndroid Build Coastguard Worker        p.start()
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker        c2p.get()  # wait for child to become ready
800*da0073e9SAndroid Build Coastguard Worker        torch.cuda._sleep(50000000)  # spin for about 50 ms
801*da0073e9SAndroid Build Coastguard Worker        e0.record()
802*da0073e9SAndroid Build Coastguard Worker        p2c.put(0)  # notify child event is recorded
803*da0073e9SAndroid Build Coastguard Worker
804*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(e0.query())
805*da0073e9SAndroid Build Coastguard Worker        c2p.get()  # wait for synchronization in child
806*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e0.query())
807*da0073e9SAndroid Build Coastguard Worker        p2c.put(1)  # notify child that parent is done
808*da0073e9SAndroid Build Coastguard Worker        p.join()
809*da0073e9SAndroid Build Coastguard Worker
810*da0073e9SAndroid Build Coastguard Worker    @staticmethod
811*da0073e9SAndroid Build Coastguard Worker    def _test_event_handle_exporter_consumer(handle, p2c, c2p):
812*da0073e9SAndroid Build Coastguard Worker        stream = torch.cuda.Stream()
813*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(stream):
814*da0073e9SAndroid Build Coastguard Worker            e1 = torch.cuda.Event.from_ipc_handle(torch.cuda.current_device(), handle)
815*da0073e9SAndroid Build Coastguard Worker            torch.cuda._sleep(50000000)  # spin for about 50 ms
816*da0073e9SAndroid Build Coastguard Worker            e1.record()
817*da0073e9SAndroid Build Coastguard Worker            c2p.put(0)
818*da0073e9SAndroid Build Coastguard Worker            # wait for parent process finished synchronization before
819*da0073e9SAndroid Build Coastguard Worker            # destructing e1
820*da0073e9SAndroid Build Coastguard Worker            p2c.get()
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
823*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
824*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
825*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
826*da0073e9SAndroid Build Coastguard Worker    )
827*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
828*da0073e9SAndroid Build Coastguard Worker    def test_event_handle_exporter(self):
829*da0073e9SAndroid Build Coastguard Worker        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
832*da0073e9SAndroid Build Coastguard Worker        p2c = ctx.SimpleQueue()
833*da0073e9SAndroid Build Coastguard Worker        c2p = ctx.SimpleQueue()
834*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(
835*da0073e9SAndroid Build Coastguard Worker            target=TestMultiprocessing._test_event_handle_exporter_consumer,
836*da0073e9SAndroid Build Coastguard Worker            args=(e0.ipc_handle(), p2c, c2p),
837*da0073e9SAndroid Build Coastguard Worker        )
838*da0073e9SAndroid Build Coastguard Worker        p.start()
839*da0073e9SAndroid Build Coastguard Worker        # wait for event in child process is recorded
840*da0073e9SAndroid Build Coastguard Worker        c2p.get()
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(e0.query())
843*da0073e9SAndroid Build Coastguard Worker        e0.synchronize()
844*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(e0.query())
845*da0073e9SAndroid Build Coastguard Worker        p2c.put(0)
846*da0073e9SAndroid Build Coastguard Worker        p.join()
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker    def _test_empty_tensor_sharing(self, dtype, device):
849*da0073e9SAndroid Build Coastguard Worker        q = mp.Queue()
850*da0073e9SAndroid Build Coastguard Worker        empty = torch.tensor([], dtype=dtype, device=device)
851*da0073e9SAndroid Build Coastguard Worker        q.put(empty)
852*da0073e9SAndroid Build Coastguard Worker        out = q.get(timeout=1)
853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, empty)
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker    def test_empty_tensor_sharing(self):
856*da0073e9SAndroid Build Coastguard Worker        self._test_empty_tensor_sharing(torch.float32, torch.device("cpu"))
857*da0073e9SAndroid Build Coastguard Worker        self._test_empty_tensor_sharing(torch.int64, torch.device("cpu"))
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
860*da0073e9SAndroid Build Coastguard Worker    def test_empty_tensor_sharing_cuda(self):
861*da0073e9SAndroid Build Coastguard Worker        self._test_empty_tensor_sharing(torch.float32, torch.device("cuda"))
862*da0073e9SAndroid Build Coastguard Worker        self._test_empty_tensor_sharing(torch.int64, torch.device("cuda"))
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker    def test_empty_tensor_sharing_meta(self):
865*da0073e9SAndroid Build Coastguard Worker        self._test_empty_tensor_sharing(torch.float32, torch.device("meta"))
866*da0073e9SAndroid Build Coastguard Worker        self._test_empty_tensor_sharing(torch.int64, torch.device("meta"))
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker    def test_tensor_sharing_meta(self):
869*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
870*da0073e9SAndroid Build Coastguard Worker        device = torch.device("meta")
871*da0073e9SAndroid Build Coastguard Worker        q = mp.Queue()
872*da0073e9SAndroid Build Coastguard Worker        empty = torch.tensor([1], dtype=dtype, device=device)
873*da0073e9SAndroid Build Coastguard Worker        q.put(empty)
874*da0073e9SAndroid Build Coastguard Worker        out = q.get(timeout=1)
875*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, empty)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker    def test_meta_simple(self):
878*da0073e9SAndroid Build Coastguard Worker        self._test_sharing(mp.get_context("spawn"), "meta", torch.float)
879*da0073e9SAndroid Build Coastguard Worker
880*da0073e9SAndroid Build Coastguard Worker    def _test_autograd_sharing(self, var, ctx=mp, is_parameter=False):
881*da0073e9SAndroid Build Coastguard Worker        device = "cuda" if var.is_cuda else "cpu"
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker        ready = ctx.Event()
884*da0073e9SAndroid Build Coastguard Worker        master_modified = ctx.Event()
885*da0073e9SAndroid Build Coastguard Worker        queue = ctx.Queue()
886*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(
887*da0073e9SAndroid Build Coastguard Worker            target=autograd_sharing,
888*da0073e9SAndroid Build Coastguard Worker            args=(queue, ready, master_modified, device, is_parameter),
889*da0073e9SAndroid Build Coastguard Worker        )
890*da0073e9SAndroid Build Coastguard Worker        p.daemon = True
891*da0073e9SAndroid Build Coastguard Worker        p.start()
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker        # This would cause an error if we tried to serialize the hooks,
894*da0073e9SAndroid Build Coastguard Worker        # because it's a closure and pickle doesn't support closures.
895*da0073e9SAndroid Build Coastguard Worker        @torch.utils.hooks.unserializable_hook
896*da0073e9SAndroid Build Coastguard Worker        def hook(*unused):
897*da0073e9SAndroid Build Coastguard Worker            pass
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker        if var.requires_grad:
900*da0073e9SAndroid Build Coastguard Worker            var.register_hook(hook)
901*da0073e9SAndroid Build Coastguard Worker        var._grad = torch.zeros(5, 5, device=device)
902*da0073e9SAndroid Build Coastguard Worker        queue.put(var)
903*da0073e9SAndroid Build Coastguard Worker
904*da0073e9SAndroid Build Coastguard Worker        ready.wait()
905*da0073e9SAndroid Build Coastguard Worker        var.data[0, 0] = 1000
906*da0073e9SAndroid Build Coastguard Worker        var.grad.data[:] = torch.ones(5, 5, device=device) * 4
907*da0073e9SAndroid Build Coastguard Worker        master_modified.set()
908*da0073e9SAndroid Build Coastguard Worker
909*da0073e9SAndroid Build Coastguard Worker        worker_ok = queue.get()
910*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(worker_ok)
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(var.data, torch.ones(5, 5, device=device))
913*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(var.grad.data, torch.ones(5, 5, device=device) * 4)
914*da0073e9SAndroid Build Coastguard Worker        p.join(100)
915*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(p.is_alive())
916*da0073e9SAndroid Build Coastguard Worker
917*da0073e9SAndroid Build Coastguard Worker    # Check sharing a cudaMalloc allocation with different types of storage.
918*da0073e9SAndroid Build Coastguard Worker    # (Issue #11422)
919*da0073e9SAndroid Build Coastguard Worker    def _test_mixed_types_cuda_sharing(self, ctx=mp):
920*da0073e9SAndroid Build Coastguard Worker        all_ones = torch.ones(2, 2).float()
921*da0073e9SAndroid Build Coastguard Worker        all_zeros = torch.zeros(2, 2).byte()
922*da0073e9SAndroid Build Coastguard Worker        queue = ctx.Queue()
923*da0073e9SAndroid Build Coastguard Worker        event = ctx.Event()
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(target=mixed_type_producer, args=(queue, event))
926*da0073e9SAndroid Build Coastguard Worker
927*da0073e9SAndroid Build Coastguard Worker        p.start()
928*da0073e9SAndroid Build Coastguard Worker
929*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
930*da0073e9SAndroid Build Coastguard Worker            float_tensor = queue.get()
931*da0073e9SAndroid Build Coastguard Worker            byte_tensor = queue.get()
932*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(float_tensor, all_ones)
933*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(byte_tensor, all_zeros)
934*da0073e9SAndroid Build Coastguard Worker            del float_tensor, byte_tensor
935*da0073e9SAndroid Build Coastguard Worker            event.set()
936*da0073e9SAndroid Build Coastguard Worker
937*da0073e9SAndroid Build Coastguard Worker        time.sleep(5)
938*da0073e9SAndroid Build Coastguard Worker        p.join()
939*da0073e9SAndroid Build Coastguard Worker
940*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
941*da0073e9SAndroid Build Coastguard Worker        TEST_WITH_ASAN,
942*da0073e9SAndroid Build Coastguard Worker        "non-deterministically hangs with ASAN https://github.com/pytorch/pytorch/issues/94024",
943*da0073e9SAndroid Build Coastguard Worker    )
944*da0073e9SAndroid Build Coastguard Worker    def test_variable_sharing(self):
945*da0073e9SAndroid Build Coastguard Worker        for requires_grad in [True, False]:
946*da0073e9SAndroid Build Coastguard Worker            var = torch.arange(1.0, 26).view(5, 5).requires_grad_(requires_grad)
947*da0073e9SAndroid Build Coastguard Worker            self._test_autograd_sharing(var)
948*da0073e9SAndroid Build Coastguard Worker
949*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/14997
950*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "non-deterministically hangs with ASAN")
951*da0073e9SAndroid Build Coastguard Worker    def test_leaf_variable_sharing(self):
952*da0073e9SAndroid Build Coastguard Worker        devices = ["cpu"]
953*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available() and not NO_MULTIPROCESSING_SPAWN and TEST_CUDA_IPC:
954*da0073e9SAndroid Build Coastguard Worker            devices.append("cuda")
955*da0073e9SAndroid Build Coastguard Worker        for device in devices:
956*da0073e9SAndroid Build Coastguard Worker            for requires_grad in [True, False]:
957*da0073e9SAndroid Build Coastguard Worker                var = (
958*da0073e9SAndroid Build Coastguard Worker                    torch.arange(1.0, 26, device=device)
959*da0073e9SAndroid Build Coastguard Worker                    .view(5, 5)
960*da0073e9SAndroid Build Coastguard Worker                    .requires_grad_(requires_grad)
961*da0073e9SAndroid Build Coastguard Worker                )
962*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(var.is_leaf)
963*da0073e9SAndroid Build Coastguard Worker                ctx = mp.get_context("spawn") if device == "cuda" else mp
964*da0073e9SAndroid Build Coastguard Worker                ready = ctx.Event()
965*da0073e9SAndroid Build Coastguard Worker                queue = ctx.Queue()
966*da0073e9SAndroid Build Coastguard Worker                p = ctx.Process(
967*da0073e9SAndroid Build Coastguard Worker                    target=requires_grad_variable_sharing, args=(queue, ready)
968*da0073e9SAndroid Build Coastguard Worker                )
969*da0073e9SAndroid Build Coastguard Worker                p.daemon = True
970*da0073e9SAndroid Build Coastguard Worker                p.start()
971*da0073e9SAndroid Build Coastguard Worker                queue.put(var)
972*da0073e9SAndroid Build Coastguard Worker                ready.wait()
973*da0073e9SAndroid Build Coastguard Worker                worker_requires_grad = queue.get()
974*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(worker_requires_grad == requires_grad)
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker    def test_non_leaf_variable_sharing(self):
977*da0073e9SAndroid Build Coastguard Worker        devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
978*da0073e9SAndroid Build Coastguard Worker        for device in devices:
979*da0073e9SAndroid Build Coastguard Worker            var0 = torch.arange(1.0, 26, device=device).view(5, 5).requires_grad_(True)
980*da0073e9SAndroid Build Coastguard Worker            var = var0 * 2
981*da0073e9SAndroid Build Coastguard Worker            # Don't use a regular Queue; it uses a background thread (which
982*da0073e9SAndroid Build Coastguard Worker            # means we can't catch the exceptions)
983*da0073e9SAndroid Build Coastguard Worker            queue = mp.SimpleQueue()
984*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(
985*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"requires_grad", lambda: queue.put(var)
986*da0073e9SAndroid Build Coastguard Worker            )
987*da0073e9SAndroid Build Coastguard Worker
988*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
989*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
990*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
991*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
992*da0073e9SAndroid Build Coastguard Worker    )
993*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
994*da0073e9SAndroid Build Coastguard Worker    def test_cuda_variable_sharing(self):
995*da0073e9SAndroid Build Coastguard Worker        for requires_grad in [True, False]:
996*da0073e9SAndroid Build Coastguard Worker            var = (
997*da0073e9SAndroid Build Coastguard Worker                torch.arange(1.0, 26, device="cuda")
998*da0073e9SAndroid Build Coastguard Worker                .view(5, 5)
999*da0073e9SAndroid Build Coastguard Worker                .requires_grad_(requires_grad)
1000*da0073e9SAndroid Build Coastguard Worker            )
1001*da0073e9SAndroid Build Coastguard Worker            self._test_autograd_sharing(var, mp.get_context("spawn"))
1002*da0073e9SAndroid Build Coastguard Worker
1003*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1004*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
1005*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
1006*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
1007*da0073e9SAndroid Build Coastguard Worker    )
1008*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1009*da0073e9SAndroid Build Coastguard Worker    def test_mixed_types_cuda_sharing(self):
1010*da0073e9SAndroid Build Coastguard Worker        self._test_mixed_types_cuda_sharing(mp.get_context("spawn"))
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker    def test_parameter_sharing(self):
1013*da0073e9SAndroid Build Coastguard Worker        param = Parameter(torch.arange(1.0, 26).view(5, 5))
1014*da0073e9SAndroid Build Coastguard Worker        self._test_autograd_sharing(param, is_parameter=True)
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1017*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
1018*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
1019*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
1020*da0073e9SAndroid Build Coastguard Worker    )
1021*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1022*da0073e9SAndroid Build Coastguard Worker    def test_cuda_parameter_sharing(self):
1023*da0073e9SAndroid Build Coastguard Worker        param = Parameter(torch.arange(1.0, 26, device="cuda").view(5, 5))
1024*da0073e9SAndroid Build Coastguard Worker        self._test_autograd_sharing(param, mp.get_context("spawn"), is_parameter=True)
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1027*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
1028*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
1029*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
1030*da0073e9SAndroid Build Coastguard Worker    )
1031*da0073e9SAndroid Build Coastguard Worker    def test_integer_parameter_serialization_cpu(self):
1032*da0073e9SAndroid Build Coastguard Worker        self._test_integer_parameter_serialization(device="cpu")
1033*da0073e9SAndroid Build Coastguard Worker
1034*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1035*da0073e9SAndroid Build Coastguard Worker        NO_MULTIPROCESSING_SPAWN,
1036*da0073e9SAndroid Build Coastguard Worker        "Disabled for environments that \
1037*da0073e9SAndroid Build Coastguard Worker                     don't support multiprocessing with spawn start method",
1038*da0073e9SAndroid Build Coastguard Worker    )
1039*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1040*da0073e9SAndroid Build Coastguard Worker    def test_integer_parameter_serialization_cuda(self):
1041*da0073e9SAndroid Build Coastguard Worker        self._test_integer_parameter_serialization(device="cuda")
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Worker    def _test_integer_parameter_serialization(self, device):
1044*da0073e9SAndroid Build Coastguard Worker        param = torch.nn.Parameter(
1045*da0073e9SAndroid Build Coastguard Worker            torch.tensor(0, dtype=torch.int64, device=device), requires_grad=False
1046*da0073e9SAndroid Build Coastguard Worker        )
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker        ctx = mp.get_context("spawn")
1049*da0073e9SAndroid Build Coastguard Worker        p = ctx.Process(target=integer_parameter_serialization, args=(param,))
1050*da0073e9SAndroid Build Coastguard Worker        p.start()
1051*da0073e9SAndroid Build Coastguard Worker        p.join()
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1054*da0073e9SAndroid Build Coastguard Worker            0,
1055*da0073e9SAndroid Build Coastguard Worker            p.exitcode,
1056*da0073e9SAndroid Build Coastguard Worker            msg=f'Failed to serialize successfully for "{device}" device!',
1057*da0073e9SAndroid Build Coastguard Worker        )
1058*da0073e9SAndroid Build Coastguard Worker
1059*da0073e9SAndroid Build Coastguard Worker    def test_empty_shared(self):
1060*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([])
1061*da0073e9SAndroid Build Coastguard Worker        t.share_memory_()
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker    def _test_is_shared(self):
1064*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(5, 5)
1065*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t.is_shared())
1066*da0073e9SAndroid Build Coastguard Worker        t.share_memory_()
1067*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t.is_shared())
1068*da0073e9SAndroid Build Coastguard Worker
1069*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1070*da0073e9SAndroid Build Coastguard Worker        platform == "darwin", "file descriptor strategy is not supported on macOS"
1071*da0073e9SAndroid Build Coastguard Worker    )
1072*da0073e9SAndroid Build Coastguard Worker    def test_is_shared(self):
1073*da0073e9SAndroid Build Coastguard Worker        self._test_is_shared()
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker    def test_fs_is_shared(self):
1076*da0073e9SAndroid Build Coastguard Worker        with fs_sharing():
1077*da0073e9SAndroid Build Coastguard Worker            self._test_is_shared()
1078*da0073e9SAndroid Build Coastguard Worker
1079*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
1080*da0073e9SAndroid Build Coastguard Worker    def test_is_shared_cuda(self):
1081*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(5, 5).cuda()
1082*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t.is_shared())
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.platform != "linux", "Only runs on Linux; requires prctl(2)")
1085*da0073e9SAndroid Build Coastguard Worker    def test_set_thread_name(self):
1086*da0073e9SAndroid Build Coastguard Worker        name = "test name"
1087*da0073e9SAndroid Build Coastguard Worker        mp._set_thread_name(name)
1088*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mp._get_thread_name(), name)
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
1092*da0073e9SAndroid Build Coastguard Worker    run_tests()
1093