# mypy: allow-untyped-defs

import sys
import threading
import time
from enum import Enum
import random
import torch
import torch.nn as nn
from datetime import timedelta
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributed.rpc import RRef
from torch.testing._internal.common_utils import IS_MACOS, skip_but_pass_in_sandcastle_if
from torch.testing._internal.dist_utils import (
    dist_init,
    initialize_pg,
    wait_until_node_failure,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu


# Right now we test up to 3-layer nested rpc calls.
# rpc_done[1] and ctx_ids[1] indicate that the rpc from the previous rank is
# done and hold the context id sent from that rank, respectively.
# rpc_done[2] and ctx_ids[2] are for the rank two hops back (prev of prev).
# rpc_done[3] and ctx_ids[3] are for the rank three hops back.
# rpc_done[0] and ctx_ids[0] are for the current rank, but are mostly unused.
rpc_done = [False, False, False, False]
ctx_ids = [-1, -1, -1, -1]

known_context_ids = set()

requires_grad_tensor = torch.ones(3, 3, requires_grad=True)

# Send rpc done info and context_id to
# dst_rank = (self.rank + rank_distance) % self.world_size.
# We don't need a lock here since the GIL is held while executing remote
# Python UDFs, so accesses from several workers are serialized.
def _set_rpc_done(ctx_id, rank_distance):
    global rpc_done
    global ctx_ids
    global known_context_ids
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id
    known_context_ids.add(ctx_id)


def _check_rpc_done(rank_distance):
    while not rpc_done[rank_distance]:
        time.sleep(0.1)

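
# A minimal, self-contained sketch (illustration only, not used by the tests)
# of the rank_distance bookkeeping above: a worker marks its RPC as done on the
# worker rank_distance hops ahead, which is the worker that later reads
# rpc_done[rank_distance] / ctx_ids[rank_distance]. The world_size and rank
# values below are hypothetical and exist purely for demonstration.
def _example_rank_distance_mapping():
    world_size = 4  # hypothetical world size, matching the 4-slot lists above
    rank = 0        # hypothetical current rank
    # _set_rpc_done(ctx_id, d) is sent to dst_rank = (rank + d) % world_size,
    # so slot d on the destination always refers to the caller d hops back.
    dst_ranks = [(rank + d) % world_size for d in (1, 2, 3)]
    assert dst_ranks == [1, 2, 3]
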

def _torch_ones(sizes, requires_grad=False):
    return torch.ones(sizes, requires_grad=requires_grad)

# This method must be called on the rref owner, and verifies that the grad of
# the rref tensor equals the given grad.
def _compare_owner_value(context_id, rref, grad):
    grads = dist_autograd.get_gradients(context_id)
    x = grads[rref.local_value()]
    if x.is_sparse:
        assert grad.is_sparse
        x = x.to_dense()
        grad = grad.to_dense()
    else:
        assert not grad.is_sparse
    return torch.equal(x, grad)


def create_tensor():
    return torch.ones((3, 3), requires_grad=True)


def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
    i = [[0, 1, 1], [2, 0, 2]]
    v = [3.2, 4.1, 5.3]
    tensor = torch.sparse_coo_tensor(
        i, v, (3, 3), requires_grad=requires_grad, dtype=dtype
    )
    if coalesce:
        tensor = tensor.coalesce()
    return tensor

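
# A small illustrative sketch (hypothetical helper, not exercised by the test
# suite) of how the sparse tensors built above are reduced to a scalar loss in
# the tests below: torch.sparse.sum() produces a differentiable scalar,
# mirroring the `loss = torch.sparse.sum(ret)` pattern used throughout this file.
def _example_sparse_loss_backward():
    t = build_sparse_tensor(requires_grad=True)
    loss = torch.sparse.sum(t)  # scalar loss over the three non-zero entries
    loss.backward()             # populates t.grad, as the local-autograd checks below expect
    assert t.grad is not None
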

@torch.jit.script
def create_torchscript_tensor() -> torch.Tensor:
    return torch.ones((3, 3)).requires_grad_()


def my_py_add(t1, t2):
    return torch.add(t1, t2)


def my_scalar_add(a, b):
    return a + b


def my_rref_add(rref_t1, t2):
    ret = torch.add(rref_t1.local_value(), t2)
    return ret


@torch.jit.script
def my_script_add(t1, t2):
    return torch.add(t1, t2)


@torch.jit.script
def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor:
    t1 = ref_t1.to_here()
    return torch.add(t1, t2)


def my_nested_rref_add(dst, rref_t1, t2):
    return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))


def ret_requires_grad():
    return requires_grad_tensor


def my_py_nested_call(t1, t2, dst, world_size, hops):
    next_dst = (dst + 1) % world_size
    if hops > 0:
        return rpc.rpc_sync(
            worker_name(next_dst),
            my_py_nested_call,
            args=(t1, t2, next_dst, world_size, hops - 1),
        )
    else:
        return rpc.rpc_sync(worker_name(next_dst), my_py_add, args=(t1, t2))


# After a dist autograd context is cleaned up locally, it should also be
# cleaned up on the other nodes. This helper allows timeout_seconds for those
# cleanup RPCs to complete, and checks that all known contexts have been
# cleaned up within that timeframe.
def _all_contexts_cleaned_up(timeout_seconds=10):
    global known_context_ids
    start = time.time()
    context_id_to_raised = set()
    while (
        time.time() - start < timeout_seconds
        and context_id_to_raised != known_context_ids
    ):
        for context_id in known_context_ids:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                context_id_to_raised.add(context_id)
    # all contexts have been cleaned up if trying to retrieve every known context resulted in a RuntimeError.
    success = context_id_to_raised == known_context_ids
    return success

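
# Illustrative usage of _all_contexts_cleaned_up (an assumption for
# documentation purposes: it needs an initialized RPC agent and process group,
# so it is only sketched here as comments). This mirrors
# context_cleanup_test_helper below:
#
#   with dist_autograd.context() as context_id:
#       ...  # issue RPCs that propagate the context to peers
#   dist.barrier()                     # peers finish mutating known_context_ids
#   assert _all_contexts_cleaned_up()  # every known context now raises on retrieval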

# This function creates a dist autograd context, runs rpc_sync on the given ps,
# and then blocks until the ps has verified that the grads are correctly accumulated.
def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
    with dist_autograd.context() as context_id:
        ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
        if sparse:
            loss = torch.sparse.sum(ret)
        else:
            loss = ret.sum()
        dist_autograd.backward(context_id, [loss])
        # prevent deleting dist autograd context
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))

# This function is the same as _run_trainer, except that the rpc calls the
# TorchScript function "my_script_ref_add" instead of the Python function "my_rref_add".
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
    with dist_autograd.context() as context_id:
        ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
        if sparse:
            loss = torch.sparse.sum(ret)
        else:
            loss = ret.sum()
        dist_autograd.backward(context_id, [loss])
        # prevent deleting dist autograd context
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))


class SimulateBackwardError(Function):
    _simulate_error = True

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, input):
        if SimulateBackwardError._simulate_error:
            raise Exception("Simulate error on backward pass")  # noqa: TRY002
        else:
            return input


class ExecMode(Enum):
    LOCAL = 1  # Run the operation locally.
    RPC_SYNC = 2  # Run the operation using rpc_sync
    REMOTE = 3  # Run the operation using remote.
    RPC_ASYNC = 4  # Run the operation using rpc_async


# Common utils for both CPU and CUDA test suites
class CommonDistAutogradTest(RpcAgentTestFixture):
    def _exec_func_with_dst(self, dst, exec_mode, method, *args):
        if ExecMode.LOCAL == exec_mode:
            if len(args) == 1 and isinstance(args[0], list):
                return method(*args[0])
            return method(*args)
        elif ExecMode.RPC_SYNC == exec_mode:
            return rpc.rpc_sync(worker_name(dst), method, args=(args))
        elif ExecMode.REMOTE == exec_mode:
            return rpc.remote(worker_name(dst), method, args=(args)).to_here()
        elif ExecMode.RPC_ASYNC == exec_mode:
            fut = rpc.rpc_async(worker_name(dst), method, args=(args))
            return fut.wait()
        else:
            raise ValueError(f"Unrecognized ExecMode {exec_mode}")

    def _exec_func(self, exec_mode, method, *args):
        return self._exec_func_with_dst(
            self._next_rank(), exec_mode, method, *args
        )

    def _next_rank(self):
        if hasattr(self, "dst_rank"):
            self.dst_rank = (self.dst_rank + 1) % self.world_size
            if self.dst_rank == self.rank:
                return self._next_rank()
        else:
            self.dst_rank = (self.rank + 1) % self.world_size
        return self.dst_rank

    def _check_rpc_done(self, rank_distance):
        _check_rpc_done(rank_distance)

    def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
        if exec_mode == ExecMode.LOCAL:
            torch.autograd.backward(tensors)
            return [arg.grad for arg in args]
        else:
            self._verify_backwards_remote(tensors, context_id, local_grads, *args)

    def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
        dist_autograd.backward(context_id, tensors)

        # Verify grads were accumulated appropriately.
        grads = dist_autograd.get_gradients(context_id)
        nargs = len(args)
        ngrads = 0
        for i in range(0, nargs):
            if local_grads[i] is not None:
                self.assertIn(args[i], grads)
                self.assertEqual(local_grads[i], grads[args[i]])
                ngrads += 1
            else:
                self.assertNotIn(args[i], grads)

        self.assertEqual(ngrads, len(grads))

    def _test_graph(self, fn, exec_mode, sparse):
        dst_rank = (self.rank + 1) % self.world_size

        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor()
                t2 = build_sparse_tensor()
            else:
                t1 = torch.ones(3, 3, requires_grad=True)
                t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank), fn, args=(t1, t2)
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            rpc.rpc_sync(
                worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
            )

            # Verify graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                next(iter(send_functions.values())),
                next(iter(recv_functions.values())),
                t1,
                t2,
                ret,
            )

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # Verify graph for previous context id.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()

    # 3-layer nested calls
    def _test_graph_for_py_nested_call(self, exec_mode, sparse):
        dst_rank = (self.rank + 1) % self.world_size

        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor(requires_grad=True)
                t2 = build_sparse_tensor(requires_grad=True)
            else:
                t1 = torch.ones(3, 3, requires_grad=True)
                t2 = torch.zeros(3, 3, requires_grad=True)
            nest_dst_rank = (dst_rank + 1) % self.world_size
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(t1, t2, dst_rank, self.world_size, 1),
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(t1, t2, dst_rank, self.world_size, 1),
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            # Barrier to ensure all RPCs are done.
            dist.barrier()

            for rd in [1, 2, 3]:
                rpc.rpc_sync(
                    worker_name((self.rank + rd) % self.world_size),
                    _set_rpc_done,
                    args=(context_id, rd),
                )

            # Barrier to ensure all set_rpc_done have completed.
            dist.barrier()

            # For self.rank, there are 4 graphs to verify:
            # The first is for the current context id, when this rank sends
            # the first rpc call.
            # The second is for the prev context id, when this rank makes the
            # 1st nested call.
            # The third is for the prev prev context id, when this rank makes
            # the 2nd nested call.
            # The last is for the prev prev prev context id, when this rank
            # executes the torch.add() operator.

            # Verify first graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                next(iter(send_functions.values())),
                next(iter(recv_functions.values())),
                t1,
                t2,
                ret,
            )

            # Verify second graph for 1st nested call.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify third graph for 2nd nested call.
            ctx = dist_autograd._retrieve_context(ctx_ids[2])
            self._verify_graph_for_nested_rpc_call(ctx)

            # verify last graph for rpc call execution.
            ctx = dist_autograd._retrieve_context(ctx_ids[3])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

    # Rank0->Rank1->Rank0
    def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse):
        dst_rank = (self.rank + 1) % self.world_size

        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor(requires_grad=True)
                t2 = build_sparse_tensor(requires_grad=True)
            else:
                t1 = torch.ones(3, 3, requires_grad=True)
                t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(
                        t1,
                        t2,
                        (self.rank - 1 + self.world_size) % self.world_size,
                        self.world_size,
                        0,
                    ),
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(
                        t1,
                        t2,
                        (self.rank - 1 + self.world_size) % self.world_size,
                        self.world_size,
                        0,
                    ),
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            rpc.rpc_sync(
                worker_name((self.rank + 1) % self.world_size),
                _set_rpc_done,
                args=(context_id, 1),
            )

            # For self.rank, there are 2 graphs to verify.
            # One is for the current context id, when this rank sends the first
            # rpc call and executes the torch.add() operator.
            # The other is for the prev context id, when this rank makes the
            # nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                next(iter(send_functions.values())),
                list(recv_functions.values())[1],
                t1,
                t2,
                ret,
            )
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

            # Verify two pairs of send and recv functions for nested
            # call
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor(requires_grad=False)
                t2 = build_sparse_tensor(requires_grad=False)
            else:
                t1 = torch.ones(3, 3, requires_grad=False)
                t2 = torch.zeros(3, 3, requires_grad=False)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank), torch.add, args=(t1, t2)
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank), torch.add, args=(t1, t2)
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            rpc.rpc_sync(
                worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
            )

            ctx = dist_autograd._current_context()
            send_functions = ctx._send_functions()
            self.assertEqual(len(send_functions), 0)
            recv_functions = ctx._recv_functions()
            self.assertEqual(len(recv_functions), 0)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # NB: RRef.to_here() always passes the autograd context to the
            # callee, as the caller does not know whether the return value
            # would contain a requires_grad tensor or not.
            #
            # rpc/remote with a udf (_set_rpc_done here) also always passes the
            # autograd context to the callee for the same reason.
            self.assertNotEqual(-1, dist_autograd._retrieve_context(ctx_ids[1]))
            dist.barrier()

    def _test_rpc_complex_args(self, exec_mode, sparse):
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                if sparse:
                    tensor = build_sparse_tensor(requires_grad=(i % 2 == 0))
                else:
                    tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0))
                tensors.append(tensor)
            dst_rank = self._next_rank()
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank), torch.stack, args=(tensors,)
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank), torch.stack, args=(tensors,)
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            self.assertEqual(torch.stack(tensors), ret)

            # Verify the appropriate tensors have been attached to the autograd graph.
            next_funcs = next(iter(dist_autograd._current_context()._send_functions().values())).next_functions
            idx = 0
            for i in range(len(next_funcs)):
                self.assertEqual(
                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                )
                self.assertEqual(tensors[i], next_funcs[i][0].variable)

            # Verify that the worker id has been recorded in the context
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 1)
            self.assertEqual(worker_ids, {dst_rank})

    def context_cleanup_test_helper(self, rpc_args, func, nested=False):
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        # Test that in dist autograd, even when the tensors communicated over
        # RPC do NOT require grad, we still clean up the dist autograd contexts
        # created on other nodes. This is because the autograd context is still
        # communicated over RPC even if the tensor arguments do not require
        # grad, as it is possible that the response does.
        if nested:
            dst_rank = (self.rank + 1) % self.world_size
            nested_dst_rank = (dst_rank + 1) % self.world_size
            dst_ranks = {dst_rank}
        else:
            dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}

        with dist_autograd.context() as context_id:
            for dst_rank in dst_ranks:
                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
                rpc.rpc_sync(
                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
                if nested:
                    rpc.rpc_sync(
                        worker_name(nested_dst_rank),
                        _set_rpc_done,
                        args=(context_id, 2),
                    )
        # the thread's context id should be cleaned up
        with self.assertRaises(RuntimeError):
            dist_autograd._retrieve_context(context_id)
        # Ensure all peers have finished mutating the
        # `known_context_ids` set.
        dist.barrier()
        # check that all contexts have been cleaned up.
        success = _all_contexts_cleaned_up()
        self.assertTrue(success)

    def _backward_no_grad_on_tensor(self, t1, t2, sparse):
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                worker_name(self._next_rank()),
                torch.add,
                args=(t1, t2))
            if sparse:
                loss = torch.sparse.sum(loss)
            else:
                loss = loss.sum()
            dist_autograd.backward(context_id, [loss], retain_graph=True)
            self.assertIsNone(t1.grad)
            self.assertIsNone(t2.grad)

            # Now populate .grad with local autograd engine and
            # verify dist autograd doesn't mess with it.
            loss_local = torch.add(t1, t2)
            if sparse:
                loss_local = torch.sparse.sum(loss_local)
            else:
                loss_local = loss_local.sum()
            loss_local.backward()
            self.assertIsNotNone(t1.grad)
            self.assertIsNotNone(t2.grad)

            t1_grad_before = t1.grad
            t2_grad_before = t2.grad
            dist_autograd.backward(context_id, [loss])
            self.assertEqual(t1_grad_before, t1.grad)
            self.assertEqual(t2_grad_before, t2.grad)

    # The current rank first creates a tensor on the rref_owner, and then passes
    # the rref along with another tensor to the callee to run either my_rref_add
    # or my_nested_rref_add, depending on whether the callee is the rref owner.
    # The grad of the local tensor lives on the current rank, and the grad of the
    # rref tensor lives on the rref owner.
    def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse):
        local_ret = torch.add(t1, t2)
        if sparse:
            local_ret = torch.sparse.sum(local_ret)
        else:
            local_ret = local_ret.sum()
        local_ret.backward()
        with dist_autograd.context() as context_id:
            if sparse:
                rref_t1 = rpc.remote(
                    rref_owner, build_sparse_tensor, args=(False, True,)
                )
            else:
                rref_t1 = rpc.remote(
                    rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
                )
            if callee == rref_owner:
                rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
            else:
                rref = rpc.remote(
                    callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
                )
            ret = rref.to_here()
            if sparse:
                ret = torch.sparse.sum(ret)
            else:
                ret = ret.sum()
            dist_autograd.backward(context_id, [ret])

            # verify grads on caller
            grads = dist_autograd.get_gradients(context_id)
            self.assertIn(t2, grads)
            self.assertEqual(grads[t2], t2.grad)

            # verify grads on rref owner
            self.assertTrue(
                rpc.rpc_sync(
                    rref_owner,
                    _compare_owner_value,
                    args=(context_id, rref_t1, t1.grad),
                )
            )

    # In this test, every rank serves as a parameter server (ps) and a
    # driver, and kicks off trainers on the other three ranks. So, we have:
    # ps = rank0 with trainers = rank1/2/3
    # ps = rank1 with trainers = rank2/3/0
    # ps = rank2 with trainers = rank3/0/1
    # ps = rank3 with trainers = rank0/1/2
    #
    # These four ps-trainer groups run on completely separate autograd
    # graphs, but they share the same set of underlying RpcAgents.
    def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse):
        if sparse:
            t1 = build_sparse_tensor(requires_grad=True)
            t2 = build_sparse_tensor(requires_grad=True)
        else:
            t1 = torch.ones((3, 3), requires_grad=True)
            t2 = torch.zeros((3, 3), requires_grad=True)

        local_ret = torch.add(t1, t2)
        if sparse:
            torch.sparse.sum(local_ret).backward()
        else:
            local_ret.sum().backward()

        # create rref on self
        rref_t1 = rpc.remote(
            worker_name(self.rank),
            create_ref_fn,
            args=())

        # kick off forward and backward pass on three other workers (trainers)
        rank_diffs = [1, 2, 3]
        futures = []
        for rank_diff in rank_diffs:
            futures.append(
                rpc.rpc_async(
                    worker_name((self.rank + rank_diff) % self.world_size),
                    trainer_fn,
                    args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
                )
            )

        # check that the trainers are done with their backward pass
        for rank_diff in rank_diffs:
            self._check_rpc_done(rank_diff)

        # trainers are done and holding the context for verification
        accumulate_grad_func = None
        for rank_diff in rank_diffs:
            # make sure grads are accumulated for the same tensors and values
            # are all correct
            ctx_id = ctx_ids[rank_diff]
            grads = dist_autograd.get_gradients(ctx_id)
            local_t1 = rref_t1.to_here()
            self.assertIn(local_t1, grads)
            self.assertEqual(grads[local_t1], t1.grad)

        # unblock trainers
        _set_rpc_done(None, 0)

        # wait until all trainers are done
        torch.futures.wait_all(futures)

    def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse):
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                # Multiple RPCs between different nodes.
                val = self._exec_func(exec_mode, torch.add, t1, t2)
                val = self._exec_func(exec_mode, torch.mul, t3, val)
                s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
                s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
                if sparse:
                    val = self._exec_func(exec_mode, torch.mul, s1, s2)
                    val = self._exec_func(exec_mode, torch.mul, val, val)
                    loss = torch.sparse.sum(val)
                else:
                    val = self._exec_func(exec_mode, torch.bmm, s1, s2)
                    val = self._exec_func(exec_mode, torch.matmul, val, val)
                    loss = val.sum()

                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
                )
                local_grads = ret if ret else local_grads

    def _backward_different_dtypes(self, t1, t2, sparse):
        local_grads = None
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                loss = self._exec_func(exec_mode, torch.add, t1, t2)
                if sparse:
                    loss = torch.sparse.sum(loss)
                else:
                    loss = loss.sum()
                local_grads = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )

    # Run the same code locally and with dist autograd and verify gradients
    # are same.
    def _backward_simple_python_udf(self, t1, t2, sparse):
        local_grads = None
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(exec_mode, my_py_add, t1, t2)
                if sparse:
                    loss = torch.sparse.sum(ret)
                else:
                    loss = ret.sum()
                local_grads = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )

    # Run the same code locally and with dist autograd and verify gradients
    # are same.
    def _backward_simple_script_call(self, t1, t2, sparse):
        local_grads = None
        for exec_mode in [
            ExecMode.LOCAL,
            ExecMode.RPC_SYNC,
            ExecMode.RPC_ASYNC,
            ExecMode.REMOTE,
        ]:
            with dist_autograd.context() as context_id:
                forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
                if sparse:
                    loss = torch.sparse.sum(forward_ret)
                else:
                    loss = forward_ret.sum()
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )
                local_grads = ret if ret else local_grads

    def _nested_backward_accumulate_grads(self, t1, t2, sparse):
        with dist_autograd.context() as context_id:
            ret = rpc.rpc_sync(
                worker_name(self._next_rank()),
                DistAutogradTest._test_nested_backward_accumulate_grads,
                args=(t1, t2, self._next_rank()),
            )
            if sparse:
                loss = torch.sparse.sum(ret)
            else:
                loss = ret.sum()
            # Run backward twice.
            dist_autograd.backward(context_id, [loss], retain_graph=True)
            dist_autograd.backward(context_id, [loss])

    def _backwards_nested_python_udf(self, t1, t2, sparse):
        t3 = t1 * t2
        t4 = t1 + t2
        res = t3 + t4
        loss = t1 * t2 * t3 * t4 * res
        if sparse:
            loss = torch.sparse.sum(loss)
        else:
            loss = loss.sum()
        torch.autograd.backward([loss])

        # Now run distributed autograd.
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                worker_name(self._next_rank()),
                DistAutogradTest._nested_python_udf,
                args=(t1, t2, self._next_rank()),
            )
            if sparse:
                loss = torch.sparse.sum(loss)
            else:
                loss = loss.sum()
            dist_autograd.backward(context_id, [loss])
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(t1.grad, grads[t1])
            self.assertEqual(t2.grad, grads[t2])

    def _mixed_requires_grad(self, t1, t2, sparse):
        for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(
                    exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2
                )
                self.assertEqual(t1 * t2, ret)
                if sparse:
                    loss = torch.sparse.sum(ret)
                else:
                    loss = ret.sum()
                dist_autograd.backward(context_id, [loss])
                self.assertTrue(t1.requires_grad)
                self.assertFalse(t2.requires_grad)
                grads = dist_autograd.get_gradients(context_id)
                self.assertIn(t1, grads)
                self.assertNotIn(t2, grads)
                self.assertEqual(t2, grads[t1])

    def _multiple_backward(self, t1, t2, sparse):
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                worker_name(self._next_rank()),
                torch.add,
                args=(t1, t2))
            if sparse:
                loss = torch.sparse.sum(loss)
            else:
                loss = loss.sum()
            # Run backward in a loop multiple times.
            for i in range(1000):
                dist_autograd.backward(context_id, [loss], retain_graph=True)

    # For the current context, this rank sends the t1 and t2 tensors to dst_rank,
    # then gets the result tensor t3 = torch.add(t1, t2) back.
    # For the current context on this rank, the expected graph looks like this:
    #  send function:
    #              rpcSendBackward
    #                  /          \
    #  t1.AccumulateGrad         t2.AccumulateGrad
    #
    #  recv function:
    #
    #            |
    #          t3.rpcRecvBackward
    #
    def _verify_graph_for_first_rpc_call(
        self, send_function, recv_function, t1, t2, ret
    ):
        # Retrieve the next functions in the graph.
        next_funcs = send_function.next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

        # Test recv functions.
        self.assertEqual(ret.grad_fn, recv_function)

    # Run the same code locally and with dist autograd and verify gradients
    # are same.
    def _backward_simple(self, dst, t1, t2, local_grads, sparse):
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func_with_dst(
                    dst, exec_mode, torch.add, t1, t2
                )
                if sparse:
                    loss = torch.sparse.sum(ret)
                else:
                    loss = ret.sum()
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )
                local_grads = ret if ret else local_grads

    # For a context passed from previous nested chain calls, this rank
    # receives two tensors t1 and t2, executes torch.add(t1, t2) and sends
    # the result tensor t3 back.
    # For this context on this rank, the expected graph looks like this:
    #  send and recv functions:
    #       rpcSendBackward
    #           |
    #          t3.AddBackward0
    #          /             \
    # t1.recvRpcBackward    t2.recvRpcBackward
    def _verify_graph_for_rpc_call_exec(self, send_function):
        # Verify next function is AddBackward0
        next_funcs = send_function.next_functions
        self.assertEqual(1, len(next_funcs))
        add_backward_fn = next_funcs[0][0]
        self.assertEqual("AddBackward0", add_backward_fn.name())

        # Verify the next two functions are the same recv backward function.
        next_funcs = add_backward_fn.next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

    # For a context passed from previous nested chain calls, this rank
    # receives two tensors t1 and t2, and forwards them via a nested rpc call
    # to the next dst. On the return route, it receives the result tensor t3
    # from the next dst and forwards t3 back to the previous caller.
    # For this context on this rank, the expected graph looks like this:
    #  send and recv functions for receiving and forwarding t1 and t2:
    #       rpcSendBackward
    #          /          \
    # t1.recvRpcBackward    t2.recvRpcBackward
    #  send and recv functions for receiving and forwarding t3:
    #       rpcSendBackward
    #             |
    #           t3.recvRpcBackward
    def _verify_graph_for_nested_rpc_call(self, ctx):
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))

        # For the send function created when making the nested rpc call, the
        # next functions of the send function are the two recv functions for
        # the two tensors received from the previous call.
        next_funcs = next(iter(send_functions.values())).next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

        # For the send function created when returning the response to the
        # previous call, the next function of the send function is the recv
        # function for the result tensor returned from the nested call.
        next_funcs = list(send_functions.values())[1].next_functions
        self.assertEqual(1, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )


class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):

    # Sparse tests only work with TensorPipeAgent.
    @dist_init
    def test_graph_for_builtin_call_sparse(self):
        self._test_graph(torch.add, ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_python_call_sparse(self):
        self._test_graph(my_py_add, ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_builtin_remote_call_sparse(self):
        self._test_graph(torch.add, ExecMode.REMOTE, True)

    @dist_init
    def test_graph_for_python_remote_call_sparse(self):
        self._test_graph(my_py_add, ExecMode.REMOTE, True)

    @dist_init
    def test_graph_for_py_nested_call_sparse(self):
        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_py_nested_remote_call_sparse(self):
        self._test_graph_for_py_nested_call(ExecMode.REMOTE, True)

    @dist_init
    def test_graph_for_py_nested_call_itself_sparse(self):
        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_py_nested_remote_call_itself_sparse(self):
        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True)

    @dist_init
    def test_no_graph_with_tensors_not_require_grad_sparse(self):
        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_no_graph_with_tensors_not_require_grad_remote_sparse(self):
        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True)

    @dist_init
    def test_rpc_complex_args_sparse(self):
        self._test_rpc_complex_args(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_remote_complex_args_sparse(self):
        self._test_rpc_complex_args(ExecMode.REMOTE, True)

    @dist_init
    def test_context_cleanup_tensor_with_grad_sparse(self):
        t1 = build_sparse_tensor(requires_grad=True)
        t2 = build_sparse_tensor(requires_grad=True)
        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)

    @dist_init
    def test_context_cleanup_tensor_no_grad_sparse(self):
        t1 = build_sparse_tensor(requires_grad=False)
        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)

    @dist_init
    def test_context_cleanup_nested_rpc_sparse(self):
        t1 = build_sparse_tensor(requires_grad=True)
        t2 = build_sparse_tensor(requires_grad=True)
        dst_rank = (self.rank + 1) % self.world_size
        args = (t1, t2, dst_rank, self.world_size, 0)
        self.context_cleanup_test_helper(
            rpc_args=args, func=my_py_nested_call, nested=True
        )

    @dist_init
    def test_backward_no_grad_on_tensor_sparse(self):
        self._backward_no_grad_on_tensor(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_backward_simple_sparse(self):
        self._backward_simple(
            self._next_rank(),
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_simple_self_sparse(self):
        self._backward_simple(
            self.rank,
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_rref_multi_sparse(self):
        if self.rank > 0:
            callee = "worker0"
            rref_owner = callee
            self._backward_rref(
                callee,
                rref_owner,
                build_sparse_tensor(requires_grad=True),
                build_sparse_tensor(requires_grad=True),
                None,
                True
            )

    @dist_init
    def test_backward_rref_sparse(self):
        callee = worker_name(self._next_rank())
        rref_owner = callee
        self._backward_rref(
            callee,
            rref_owner,
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_rref_nested_sparse(self):
        callee = worker_name((self.rank + 1) % self.world_size)
        rref_owner = worker_name((self.rank + 2) % self.world_size)
        self._backward_rref(
            callee,
            rref_owner,
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_trainer_ps_sparse(self):
        self._test_trainer_ps(
            build_sparse_tensor,
            _run_trainer,
            True
        )

    @dist_init
    def test_backward_multiple_round_trips_sparse(self):
        self._backward_multiple_round_trips(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=False),
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=False),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_different_dtypes_sparse(self):
        self._backward_different_dtypes(
            build_sparse_tensor(requires_grad=True, dtype=torch.float32),
            build_sparse_tensor(requires_grad=True, dtype=torch.float64),
            True
        )

    @dist_init
    def test_backward_simple_python_udf_sparse(self):
        self._backward_simple_python_udf(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_backward_simple_script_call_sparse(self):
        self._backward_simple_script_call(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_nested_backward_accumulate_grads_sparse(self):
        self._nested_backward_accumulate_grads(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_backwards_nested_python_udf_sparse(self):
        # Run equivalent of _nested_python_udf locally.
        self._backwards_nested_python_udf(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_mixed_requires_grad_sparse(self):
        self._mixed_requires_grad(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=False),
            True
        )

    @dist_init
    def test_multiple_backward_sparse(self):
        self._multiple_backward(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_embedding_bag_with_no_grad_tensors(self):
        dst = self._next_rank()
        remote_embedding = rpc.remote(
            worker_name(dst),
            torch.nn.EmbeddingBag,
            args=(16, 16),
            kwargs={"mode": "sum", "sparse": True},
        )
        local_embedding = torch.nn.EmbeddingBag(16, 16, mode="sum", sparse=True)

        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        # requires_grad = True to record send/recv functions
        per_sample_weights = torch.rand((8), requires_grad=True)
        offsets = torch.LongTensor([0, 4])

        local_res = local_embedding(input, offsets, per_sample_weights)

        # Run backward twice.
        torch.autograd.backward([local_res.sum()], retain_graph=True)
        torch.autograd.backward([local_res.sum()])
        local_grad = local_embedding.weight.grad

        with dist_autograd.context() as context_id:
            res = rpc.rpc_sync(
                worker_name(dst),
                DistAutogradTest._call_remote_embedding,
                args=(remote_embedding, input, offsets, per_sample_weights),
            )

            # Run backward twice to test accumulation of sparse gradients.
            dist_autograd.backward(context_id, [res.sum()], retain_graph=True)
            dist_autograd.backward(context_id, [res.sum()])

            remote_grad = rpc.rpc_sync(
                worker_name(dst),
                DistAutogradTest._get_grad,
                args=(remote_embedding, context_id),
            )

            self.assertEqual(local_grad, remote_grad)


class DistAutogradTest(CommonDistAutogradTest):
    @dist_init
    def test_autograd_context(self):
        # Verify max possible id.
        max_auto_increment = 281474976710655
        self.assertEqual(
            max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id()
        )

        context_ids = []
        for i in range(200):
            with dist_autograd.context() as context_id:
                self.assertEqual(
                    context_id,
                    dist_autograd._retrieve_context(context_id)._context_id(),
                )
                # The top 16 bits should be the worker_id.
                self.assertEqual(self.worker_id, context_id >> 48)
                context_ids.append(context_id)

        for context_id in context_ids:
            with self.assertRaisesRegex(
                RuntimeError,
                f"Could not find autograd context with id: {context_id}",
            ):
                dist_autograd._retrieve_context(context_id)

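    # Illustrative helper (an assumption added for documentation purposes, not
    # part of the dist_autograd API) showing the context id layout that the
    # assertions above rely on: the top 16 bits hold the worker id and the low
    # 48 bits hold an auto-incrementing counter (hence the max of 2**48 - 1).
    @staticmethod
    def _example_split_context_id(context_id):
        worker_id = context_id >> 48
        local_counter = context_id & ((1 << 48) - 1)
        return worker_id, local_counter
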
1299    @dist_init
1300    def test_nested_context(self):
1301        with dist_autograd.context() as context_id:
1302            # Nested contexts not supported.
1303            with self.assertRaisesRegex(
1304                RuntimeError, "Already have an autograd context id for this thread"
1305            ):
1306                with dist_autograd.context() as context_id:
1307                    pass
1308
1309    @dist_init
1310    def test_graph_for_builtin_call(self):
1311        self._test_graph(torch.add, ExecMode.RPC_SYNC, False)
1312
1313    @dist_init
1314    def test_graph_for_python_call(self):
1315        self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)
1316
1317    @dist_init
1318    def test_graph_for_builtin_remote_call(self):
1319        self._test_graph(torch.add, ExecMode.REMOTE, False)
1320
1321    @dist_init
1322    def test_graph_for_python_remote_call(self):
1323        self._test_graph(my_py_add, ExecMode.REMOTE, False)
1324
1325    @dist_init
1326    def test_graph_for_py_nested_call(self):
1327        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False)
1328
1329    @dist_init
1330    def test_graph_for_py_nested_remote_call(self):
1331        self._test_graph_for_py_nested_call(ExecMode.REMOTE, False)
1332
1333    @dist_init
1334    def test_graph_for_py_nested_call_itself(self):
1335        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False)
1336
1337    @dist_init
1338    def test_graph_for_py_nested_remote_call_itself(self):
1339        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False)
1340
1341    @dist_init
1342    def test_no_graph_with_tensors_not_require_grad(self):
1343        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False)
1344
1345    @dist_init
1346    def test_no_graph_with_tensors_not_require_grad_remote(self):
1347        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False)
1348
1349    def _test_grad_only_on_return_value(self, exec_mode):
1350        initialize_pg(self.file_init_method, self.rank, self.world_size)
1351        dst_rank = (self.rank + 1) % self.world_size
1352        with dist_autograd.context() as context_id:
1353            if ExecMode.RPC_SYNC == exec_mode:
1354                ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad)
1355            elif ExecMode.REMOTE == exec_mode:
1356                ret = rpc.remote(
1357                    worker_name(dst_rank), ret_requires_grad
1358                ).to_here()
1359            else:
1360                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
1361
1362            dist_autograd.backward(context_id, [ret.sum()])
1363
1364            rpc.rpc_sync(
1365                worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
1366            )
1367
1368            # Wait for the prev rank to be done with rpc.
1369            self._check_rpc_done(1)
1370            grads = dist_autograd.get_gradients(ctx_ids[1])
1371            self.assertEqual(1, len(grads))
1372            self.assertIn(requires_grad_tensor, grads)
1373            self.assertEqual(torch.ones_like(ret), grads[requires_grad_tensor])
1374            # Due to the above get_gradients call, barrier here so that dist autograd
1375            # contexts aren't cleaned up until all workers exit their context managers.
1376            dist.barrier()
1377
1378    @dist_init
1379    def test_grad_only_on_return_value(self):
1380        self._test_grad_only_on_return_value(ExecMode.RPC_SYNC)
1381
1382    @dist_init
1383    def test_grad_only_on_return_value_remote(self):
1384        self._test_grad_only_on_return_value(ExecMode.REMOTE)
1385
1386    @dist_init
1387    def test_rpc_complex_args(self):
1388        self._test_rpc_complex_args(ExecMode.RPC_SYNC, False)
1389
1390    @dist_init
1391    def test_remote_complex_args(self):
1392        self._test_rpc_complex_args(ExecMode.REMOTE, False)
1393
1394    @dist_init
1395    def test_context_cleanup_tensor_with_grad(self):
1396        t1 = torch.ones(3, 3, requires_grad=True)
1397        t2 = torch.zeros(3, 3, requires_grad=True)
1398        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
1399
1400    @dist_init
1401    def test_context_cleanup_tensor_no_grad(self):
1402        t1 = torch.ones(3, 3, requires_grad=False)
1403        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
1404
1405    @dist_init
1406    def test_context_cleanup_no_tensors(self):
1407        self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add)
1408
1409    @dist_init
1410    def test_context_cleanup_nested_rpc(self):
1411        t1 = torch.ones(3, 3, requires_grad=True)
1412        t2 = torch.zeros(3, 3, requires_grad=True)
1413        dst_rank = (self.rank + 1) % self.world_size
1414        args = (t1, t2, dst_rank, self.world_size, 0)
1415        self.context_cleanup_test_helper(
1416            rpc_args=args, func=my_py_nested_call, nested=True
1417        )
1418
1419    @dist_init
1420    def test_worker_ids_recorded(self):
1421        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
1422        with dist_autograd.context() as context_id:
1423            # if no tensors require grad, we should still record worker_ids, as
1424            # the autograd context ID is still passed to other workers.
1425            t1 = torch.ones(3, 3, requires_grad=False)
1426            t2 = torch.zeros(3, 3, requires_grad=False)
1427            for dst_rank in dst_ranks:
1428                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
1429                rpc.rpc_sync(
1430                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
1431                )
1432            # all worker_ids in dst_ranks should be recorded.
1433            ctx = dist_autograd._current_context()
1434            worker_ids = ctx._known_worker_ids()
1435            self.assertEqual(worker_ids, dst_ranks)
1436
1437            # worker_ids should be recorded when tensors do require grad
1438            t1.requires_grad = True
1439            t2.requires_grad = True
1440            for dst_rank in dst_ranks:
1441                ret = rpc.rpc_sync(
1442                    worker_name(dst_rank), torch.add, args=(t1, t2)
1443                )
1444                rpc.rpc_sync(
1445                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
1446                )
1447            # all worker_ids in dst_ranks should be recorded.
1448            worker_ids = ctx._known_worker_ids()
1449            self.assertEqual(worker_ids, dst_ranks)
1450
1451    @dist_init
1452    def test_dist_autograd_profiling(self):
1453        with dist_autograd.context() as context_id:
1454            t1 = torch.rand(3, 3, requires_grad=True)
1455            t2 = torch.rand(3, 3, requires_grad=True)
1456            loss = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2)).sum()
1457            with torch.autograd.profiler.profile() as p:
1458                dist_autograd.backward(context_id, [loss])
1459
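        # The collected profile should contain the send/recv backward functions
        # recorded by distributed autograd as well as the top-level backward call;
        # the assertions below look them up by (partial) event name.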
1460        function_events = p.function_events
1461
1462        def get_event(partial_key):
1463            return next(event for event in function_events if partial_key in event.name)
1464
1465        send_event = get_event("SendRpcBackward")
1466        recv_event = get_event("RecvRpcBackward")
1467        backward_event = get_event("torch::distributed::autograd::backward")
1468        # There should be at least one send and one recv event each, corresponding to the send/recv functions executed.
1469        self.assertEqual(send_event.count, 1)
1470        self.assertEqual(recv_event.count, 1)
1471        # The CPU total for the backward event should be greater than send and recv, since
1472        # applying those functions in the backward pass is a subset of the entire backward pass.
1473        self.assertGreater(backward_event.cpu_time_total, send_event.cpu_time_total)
1474        self.assertGreater(backward_event.cpu_time_total, recv_event.cpu_time_total)
1475
1476    @dist_init
1477    def test_error_in_context(self):
1478        with dist_autograd.context() as context_id:
1479            t1 = torch.rand(3, 3, requires_grad=True)
1480            t2 = torch.rand(6, 6, requires_grad=True)
1481
1482            with self.assertRaises(RuntimeError):
1483                # This should throw an error since matrix sizes don't match.
1484                rpc.rpc_sync(
1485                    worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
1486                )
1487
1488    @dist_init
1489    def test_backward_no_grad_on_tensor(self):
1490        self._backward_no_grad_on_tensor(
1491            torch.rand((3, 3), requires_grad=True),
1492            torch.rand((3, 3), requires_grad=True),
1493            False
1494        )
1495
1496    @dist_init
1497    def test_backward_simple(self):
1498        self._backward_simple(
1499            self._next_rank(),
1500            torch.rand((3, 3), requires_grad=True),
1501            torch.rand((3, 3), requires_grad=True),
1502            None,
1503            False
1504        )
1505
1506    @dist_init
1507    def test_backward_simple_self(self):
1508        self._backward_simple(
1509            self.rank,
1510            torch.rand((3, 3), requires_grad=True),
1511            torch.rand((3, 3), requires_grad=True),
1512            None,
1513            False
1514        )
1515
1516    @dist_init
1517    def test_backward_rref(self):
1518        callee = worker_name(self._next_rank())
1519        rref_owner = callee
1520        self._backward_rref(
1521            callee,
1522            rref_owner,
1523            torch.rand((3, 3), requires_grad=True),
1524            torch.rand((3, 3), requires_grad=True),
1525            None,
1526            False
1527        )
1528
1529    @dist_init
1530    def test_backward_rref_multi(self):
1531        if self.rank > 0:
1532            callee = "worker0"
1533            rref_owner = callee
1534            self._backward_rref(
1535                callee,
1536                rref_owner,
1537                torch.rand((3, 3), requires_grad=True),
1538                torch.rand((3, 3), requires_grad=True),
1539                None,
1540                False
1541            )
1542
1543    @dist_init
1544    def test_backward_rref_nested(self):
1545        callee = worker_name((self.rank + 1) % self.world_size)
1546        rref_owner = worker_name((self.rank + 2) % self.world_size)
1547        self._backward_rref(
1548            callee,
1549            rref_owner,
1550            torch.rand((3, 3), requires_grad=True),
1551            torch.rand((3, 3), requires_grad=True),
1552            None,
1553            False
1554        )
1555
1556    @dist_init
1557    def test_trainer_ps(self):
1558        self._test_trainer_ps(
1559            create_tensor,
1560            _run_trainer,
1561            False
1562        )
1563
1564    @dist_init
1565    def test_trainer_ps_torchscript_functions(self):
1566        # TODO: needs more investigation.
1567        # There is an RRef leak when shutting down; we suspect it is because an
1568        # RRef passed as an argument crosses the pybind boundary and is not
1569        # garbage collected by Python when shutdown() is called.
1570        import torch.distributed.rpc.api as api
1571        api._ignore_rref_leak = True
1572
1573        self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False)
1574
1575    @dist_init
1576    def test_backward_multiple_round_trips(self):
1577        self._backward_multiple_round_trips(
1578            torch.rand((3, 3), requires_grad=True),
1579            torch.rand((3, 3)),
1580            torch.rand((3, 3), requires_grad=True),
1581            torch.rand((3, 3)),
1582            torch.rand((3, 3), requires_grad=True),
1583            None,
1584            False
1585        )
1586
1587    @dist_init
1588    def test_backward_different_tensor_dims(self):
1589        local_grads = None
1590        t1 = torch.rand((4, 6), requires_grad=True)
1591        t2 = torch.rand((6, 5))
1592        t3 = torch.rand((5, 7), requires_grad=True)
1593        t4 = torch.rand((7, 9))
1594
1595        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
1596            with dist_autograd.context() as context_id:
1597                val = self._exec_func(exec_mode, torch.matmul, t1, t2)
1598                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (val, t3, t4))
1599                loss = val.sum()
1600
1601                ret = self._verify_backwards(
1602                    exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4
1603                )
1604                local_grads = ret if ret else local_grads
1605
1606    @dist_init
1607    def test_backward_unused_tensors(self):
1608        local_grads = None
1609        t1 = torch.rand((3, 3), requires_grad=True)
1610        t2 = torch.rand((3, 3), requires_grad=True)
1611        t3 = torch.rand((3, 3), requires_grad=True)
1612        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
1613            with dist_autograd.context() as context_id:
1614                s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
1615                val = self._exec_func(
1616                    exec_mode,
1617                    torch.matmul,
1618                    torch.narrow(s, 0, 0, 1),
1619                    torch.narrow(s, 0, 2, 1),
1620                )
1621
1622                loss = val.sum()
1623                ret = self._verify_backwards(
1624                    exec_mode, [loss], context_id, local_grads, t1, t2, t3
1625                )
1626                local_grads = ret if ret else local_grads
1627
1628    @dist_init
1629    def test_backward_multiple_output_tensors(self):
1630        local_grads = None
1631        t = torch.rand((10, 2), requires_grad=True)
1632        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
1633            with dist_autograd.context() as context_id:
1634                tensor_list = self._exec_func(exec_mode, torch.split, t, 2)
1635                t1 = tensor_list[0]
1636                t2 = tensor_list[2]
1637                t3 = tensor_list[4]
1638
1639                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (t1, t2, t3))
1640
1641                loss = val.sum()
1642                ret = self._verify_backwards(
1643                    exec_mode, [loss], context_id, local_grads, t
1644                )
1645                local_grads = ret if ret else local_grads
1646
1647    def _run_test_backward_unused_send_function_in_thread(self):
1648        with dist_autograd.context() as context_id:
1649            t1 = torch.rand((3, 3), requires_grad=True)
1650            t2 = torch.rand((3, 3), requires_grad=True)
1651
1652            # We don't use the result of this RPC in the loss; as a result, the
1653            # backward pass would hang in the "FAST" mode.
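            # (The FAST-mode algorithm assumes every send function recorded in the
            # context participates in the backward pass, so the engine would keep
            # waiting for a gradient for this unused send.)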
1654            res = rpc.rpc_sync(
1655                worker_name(self._next_rank()), torch.add, args=(t1, t2)
1656            )
1657
1658            val = torch.mul(t1, t2)
1659
1660            # Run backward, this would hang forever.
1661            dist_autograd.backward(context_id, [val.sum()])
1662
1663    @dist_init
1664    def test_backward_unused_send_function(self):
1665        # Run the test in a thread which would never finish.
1666        t = threading.Thread(
1667            target=self._run_test_backward_unused_send_function_in_thread
1668        )
1669        t.daemon = True
1670        t.start()
1671        t.join(10)  # Wait for 10s.
1672
1673        # Verify thread is still alive (indicating backward hasn't completed yet).
1674        self.assertTrue(t.is_alive())
1675
1676    @dist_init
1677    def test_backward_autograd_engine_error(self):
1678        with dist_autograd.context() as context_id:
1679            t1 = torch.rand((3, 3), requires_grad=True)
1680            t2 = torch.rand((3, 3), requires_grad=True)
1681            # Perform some ops before error simulation.
1682            tmp = (t1 + t2) * (t1 + t2)
1683            t3 = SimulateBackwardError.apply(tmp)
1684
1685            # Run multiple round trips across different nodes and verify the
1686            # original node receives an error thrown on a node deep in the chain.
1687            val = rpc.rpc_sync(
1688                worker_name(self._next_rank()), torch.add, args=(t2, t3)
1689            )
1690            val = rpc.rpc_sync(
1691                worker_name(self._next_rank()), torch.mul, args=(val, t2)
1692            )
1693            val = rpc.rpc_sync(
1694                worker_name(self._next_rank()), torch.matmul, args=(val, t2)
1695            )
1696            val = rpc.rpc_sync(
1697                worker_name(self._next_rank()), torch.div, args=(val, t2)
1698            )
1699
1700            with self.assertRaisesRegex(
1701                RuntimeError, "Error on Node [0-9]+: Simulate error on backward pass"
1702            ):
1703                # Run backwards, and validate we receive an error.
1704                dist_autograd.backward(context_id, [val.sum()])
1705
1706    @dist_init(clean_shutdown=False)
1707    @skip_but_pass_in_sandcastle_if(
1708        IS_MACOS,
1709        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
1710    )
1711    def test_backward_node_failure(self):
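        # Set a short timeout to quickly time out failed RPCs.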
1712        rpc._set_rpc_timeout(5)  # 5 seconds
1713        initialize_pg(self.file_init_method, self.rank, self.world_size)
1714
1715        with dist_autograd.context() as context_id:
1716            t1 = torch.rand((3, 3), requires_grad=True)
1717            t2 = torch.rand((3, 3), requires_grad=True)
1718            res = rpc.rpc_sync(
1719                worker_name(self._next_rank()), torch.add, args=(t1, t2)
1720            )
1721
1722            # Wait for all RPCs to be done.
1723            dist.barrier()
1724
1725            # Odd-rank nodes exit (simulating failure); even ranks wait for them to die and verify backward raises.
1726            if self.rank % 2 == 0:
1727                shutdown_error_regex = self.get_shutdown_error_regex()
1728                # Wait for all other nodes to die.
1729                for rank in range(self.world_size):
1730                    if rank % 2 != 0:
1731                        wait_until_node_failure(rank, shutdown_error_regex)
1732
1733                # Shutdown sequence is not very well defined and as a result
1734                # we might see any error given by get_shutdown_error_regex()
1735                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
1736                    # Run backwards, and validate we receive an error since all
1737                    # other nodes are dead.
1738                    dist_autograd.backward(context_id, [res.sum()])
1739            else:
1740                # Odd ranks simply fall through and exit, simulating node failure.
1741                pass
1742
1743    @dist_init
1744    def test_backward_without_context(self):
1745        t1 = torch.rand((3, 3), requires_grad=True)
1746        t2 = torch.rand((3, 3), requires_grad=True)
1747
1748        context_id = 100  # dummy context_id
1749        with self.assertRaisesRegex(
1750            RuntimeError,
1751            f"Could not find autograd context with id: {context_id}",
1752        ):
1753            res = rpc.rpc_sync(
1754                worker_name(self._next_rank()), torch.add, args=(t1, t2)
1755            )
1756            dist_autograd.backward(context_id, [res.sum()])
1757
1758    @dist_init
1759    def test_backward_without_rpc(self):
1760        dst_rank = self.rank
1761        with dist_autograd.context() as context_id:
1762            t1 = torch.rand((3, 3), requires_grad=True)
1763            t2 = torch.rand((3, 3), requires_grad=True)
1764            t3 = torch.add(t1, t2)
1765
1766            dist_autograd.backward(context_id, [t3.sum()])
1767            grads = dist_autograd.get_gradients(context_id)
1768            self.assertEqual(2, len(grads))
1769            self.assertIn(t1, grads)
1770            self.assertIn(t2, grads)
1771            self.assertEqual(torch.ones(3, 3), grads[t1])
1772            self.assertEqual(torch.ones(3, 3), grads[t2])
1773
1774    @dist_init
1775    def test_backward_invalid_args(self):
1776        with dist_autograd.context() as context_id:
1777
1778            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
1779                dist_autograd.backward(context_id, None)
1780
1781            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
1782                dist_autograd.backward(None, None)
1783
1784            with self.assertRaisesRegex(
1785                RuntimeError, "No tensors provided for gradient computation"
1786            ):
1787                dist_autograd.backward(context_id, [])
1788
1789            with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
1790                t = torch.rand(3, 3)
1791                dist_autograd.backward(context_id, [t])
1792
1793            with self.assertRaisesRegex(
1794                RuntimeError, "is not a scalar, all roots need to be scalar"
1795            ):
1796                t = torch.rand(3, 3, requires_grad=True)
1797                dist_autograd.backward(context_id, [t])
1798
1799            with self.assertRaisesRegex(
1800                RuntimeError, "does not have a valid gradient function"
1801            ):
1802                t = torch.rand(1, requires_grad=True)
1803                dist_autograd.backward(context_id, [t])
1804
1805    @dist_init
1806    def test_backward_multiple_roots(self):
1807        local_grads = None
1808        t1 = torch.rand((3, 3), requires_grad=True)
1809        t2 = torch.rand((3, 3), requires_grad=True)
1810        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
1811            with dist_autograd.context() as context_id:
1812                r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum()
1813                r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum()
1814                r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
1815                r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()
1816
1817                local_grads = self._verify_backwards(
1818                    exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
1819                )
1820
1821    @dist_init
1822    def test_backward_different_dtypes(self):
1823        self._backward_different_dtypes(
1824            torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
1825            torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
1826            False
1827        )
1828
1829    @dist_init
1830    def test_backward_simple_python_udf(self):
1831        self._backward_simple_python_udf(
1832            torch.rand(3, 3, requires_grad=True),
1833            torch.rand(3, 3, requires_grad=True),
1834            False
1835        )
1836
1837    @dist_init
1838    def test_backward_simple_script_call(self):
1839        self._backward_simple_script_call(
1840            torch.rand(3, 3, requires_grad=True),
1841            torch.rand(3, 3, requires_grad=True),
1842            False
1843        )
1844
1845    @staticmethod
1846    def _complex_python_udf(t1, t2):
1847        t3 = torch.nn.functional.linear(t1, t2)
1848        t4 = torch.nn.functional.linear(t2, t3)
1849        t5 = torch.nn.functional.linear(t3, t4)
1850        return torch.linalg.multi_dot([t1, t2, t3, t4, t5])
1851
1852    @dist_init
1853    def test_backward_complex_python_udf(self):
1854        # Run the same code locally and with dist autograd and verify gradients
1855        # are same.
1856        local_grads = None
1857        t1 = torch.rand((3, 3), requires_grad=True)
1858        t2 = torch.rand((3, 3), requires_grad=True)
1859        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
1860            with dist_autograd.context() as context_id:
1861                ret = self._exec_func(
1862                    exec_mode, DistAutogradTest._complex_python_udf, t1, t2
1863                )
1864                loss = ret.sum()
1865                local_grads = self._verify_backwards(
1866                    exec_mode, [loss], context_id, local_grads, t1, t2
1867                )
1868
1869    @staticmethod
1870    def _python_udf_with_backward_error(t1, t2):
1871        t3 = t1 + t2
1872        t4 = SimulateBackwardError.apply(t3)
1873        return torch.linalg.multi_dot([t1, t2, t3, t4])
1874
1875    @staticmethod
1876    def _nested_rpc_call_backward_error(t1, t2, dst):
1877        t1 = t1 * t2
1878        t2 = t1 + t2
1879        res = rpc.rpc_sync(
1880            worker_name(dst),
1881            DistAutogradTest._python_udf_with_backward_error,
1882            args=(t1, t2),
1883        )
1884        return torch.linalg.multi_dot([t1, t2, res])
1885
1886    @dist_init
1887    def test_backward_python_udf_error(self):
1888        t1 = torch.rand((3, 3), requires_grad=True)
1889        t2 = torch.rand((3, 3), requires_grad=True)
1890        with dist_autograd.context() as context_id:
1891            loss = rpc.rpc_sync(
1892                worker_name(self._next_rank()),
1893                DistAutogradTest._nested_rpc_call_backward_error,
1894                args=(t1, t2, self._next_rank()),
1895            )
1896            with self.assertRaisesRegex(
1897                RuntimeError, "Simulate error on backward pass"
1898            ):
1899                dist_autograd.backward(context_id, [loss.sum()])
1900
1901    _backward_done = False
1902
1903    @dist_init(clean_shutdown=False)
1904    @skip_but_pass_in_sandcastle_if(
1905        IS_MACOS,
1906        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
1907    )
1908    def test_backward_node_failure_python_udf(self):
1909        # Set a short timeout to quickly time out failed RPCs.
1910        rpc._set_rpc_timeout(5)  # 5 seconds
1911        initialize_pg(self.file_init_method, self.rank, self.world_size)
1912
1913        with dist_autograd.context() as context_id:
1914            t1 = torch.rand((3, 3), requires_grad=True)
1915            t2 = torch.rand((3, 3), requires_grad=True)
1916
1917            dst = self._next_rank()
1918            res = rpc.rpc_sync(
1919                worker_name(dst),
1920                my_py_nested_call,
1921                args=(t1, t2, dst, self.world_size, 1),
1922            )
1923
1924            dist.barrier()
1925
1926            # Kill rank 2 (last hop of nested rpc) and verify rank 0 receives an error.
1927            if self.rank == 2:
1928                return
1929
1930            store = dist.distributed_c10d._get_default_store()
1931            if self.rank == 0:
1932                # Wait for rank 2 to die.
1933                shutdown_error_regex = self.get_shutdown_error_regex()
1934                wait_until_node_failure(2, shutdown_error_regex)
1935                # Shutdown sequence is not very well defined and as a result
1936                # we might see any error given by get_shutdown_error_regex().
1937                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
1938                    # Run backwards, and validate we receive an error since rank 2 is dead.
1939                    dist_autograd.backward(context_id, [res.sum()])
1940
1941                # Mark rank 0 as done in the store, since the RPC framework on
1942                # some nodes might be broken at this point.
1943                store.set('test_backward_node_failure_python_udf_rank0_done', "True")
1944            else:
1945                # Wait for backward to finish on rank 0.
1946                store.wait(['test_backward_node_failure_python_udf_rank0_done'], timedelta(seconds=10))
1947
1948    @staticmethod
1949    def _nested_python_udf(t1, t2, dst):
1950        t3 = t1 * t2
1951        t4 = t1 + t2
1952        res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
1953        return t1 * t2 * t3 * t4 * res
1954
1955    @dist_init
1956    def test_backwards_nested_python_udf(self):
1957        # Run equivalent of _nested_python_udf locally.
1958        self._backwards_nested_python_udf(
1959            torch.rand(3, 3, requires_grad=True),
1960            torch.rand(3, 3, requires_grad=True),
1961            False
1962        )
1963
1964    _test_clean_context_backward_context_id = None
1965
1966    class MyBackwardFunc(Function):
1967        @staticmethod
1968        def forward(ctx, input):
1969            return input
1970
1971        @staticmethod
1972        @once_differentiable
1973        def backward(ctx, input):
1974            assert DistAutogradTest._test_clean_context_backward_context_id is not None
1975
1976            # Release the context to simulate error (use barrier before releasing
1977            # context to ensure all nodes execute the backward function).
1978            dist.barrier()
1979            dist_autograd._release_context(
1980                DistAutogradTest._test_clean_context_backward_context_id
1981            )
1982
1983            # Verify all contexts are cleaned up.
1984            assert _all_contexts_cleaned_up()
1985
1986            return input
1987
1988    @dist_init
1989    def test_clean_context_during_backward(self):
1990        """
1991        This test simulates the situation where the 'backward' call might throw
1992        an exception locally which would lead to the autograd context being
1993        cleaned up if we're using the context manager. As a result, the autograd
1994        context might be cleaned up while some threads are still using the
1995        autograd context.
1996
1997        It is fine for the 'backward' call to throw an exception in this test,
1998        but the process should not crash.
1999        """
2000        initialize_pg(self.file_init_method, self.rank, self.world_size)
2001
2002        context = dist_autograd._new_context()
2003        context_id = context._context_id()
2004        DistAutogradTest._test_clean_context_backward_context_id = context_id
2005
2006        # Send the context id to all nodes.
2007        for i in range(0, self.world_size):
2008            if i != self.rank:
2009                rank_distance = (i - self.rank + self.world_size) % self.world_size
2010                rpc.rpc_sync(
2011                    worker_name(i),
2012                    _set_rpc_done,
2013                    args=(context_id, rank_distance),
2014                )
2015
2016        dist.barrier()
2017
2018        # Verify all context ids have been received.
2019        self.assertEqual(self.world_size - 1, len(known_context_ids))
2020
2021        t1 = torch.rand((3, 3), requires_grad=True)
2022        for i in range(0, 100):
2023            dst = self._next_rank()
2024            t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))
2025
2026        # Call MyBackwardFunc as the first op of the backward pass to
2027        # ensure we release the context early in the backward pass.
2028        t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
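        # One send function was recorded in the context for each of the 100 RPCs above.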
2029        self.assertEqual(100, len(context._send_functions()))
2030
2031        # Run backward with the real context id; MyBackwardFunc releases it mid-pass, so backward should fail below.
2032        with self.assertRaisesRegex(
2033            RuntimeError,
2034            f"Could not find autograd context with id: {context_id}",
2035        ):
2036            dist_autograd.backward(context_id, [t1.sum()])
2037
2038        # HACK: Killing workers since otherwise the autograd engine gets stuck on
2039        # other nodes. The proper fix would be addressing:
2040        # https://github.com/pytorch/pytorch/issues/27643, which would inform
2041        # other nodes about the failure.
2042        # The autograd engine gets stuck on other nodes since they're waiting to
2043        # receive gradients from the node that received an error (and as a
2044        # result it didn't execute the rest of the graph).
2045        dist.barrier()
2046        rpc.shutdown(graceful=False)
2047        sys.exit(0)
2048
2049    @classmethod
2050    def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
2051        embedding = embedding_rref.local_value()
2052        return embedding(input, offsets, per_sample_weights)
2053
2054    @classmethod
2055    def _get_grad(cls, embedding_rref, context_id):
2056        embedding = embedding_rref.local_value()
2057        grad_map = dist_autograd.get_gradients(context_id)
2058        return grad_map[embedding.weight]
2059
2060    @classmethod
2061    def _mixed_requires_grad_operaton(cls, t1, t2):
2062        if t2.requires_grad:
2063            return t1 - t2
2064        else:
2065            return t1 * t2
2066
2067    @dist_init
2068    def test_mixed_requires_grad(self):
2069        self._mixed_requires_grad(
2070            torch.rand(3, 3, requires_grad=True),
2071            torch.rand(3, 3, requires_grad=False),
2072            False
2073        )
2074
2075    class TestDebugInfoFunc(Function):
2076        @staticmethod
2077        def forward(ctx, input):
2078            return input
2079
2080        @staticmethod
2081        @once_differentiable
2082        def backward(ctx, input):
2083            debug_info = dist_autograd._get_debug_info()
2084            assert debug_info is not None
2085            backward_passes = int(debug_info["num_current_backward_passes"])
2086
2087            # Hard to validate exact numbers because of the distributed nature.
2088            # We can't use a barrier() here since that would block the single
2089            # CPU thread available for autograd and can cause deadlocks.
2090            assert backward_passes >= 1 and backward_passes <= 4
2091            return input
2092
2093    @dist_init
2094    def test_debug_info(self):
2095        initialize_pg(self.file_init_method, self.rank, self.world_size)
2096
2097        t1 = torch.rand((3, 3), requires_grad=True)
2098        t2 = torch.rand((3, 3), requires_grad=True)
2099        with dist_autograd.context() as context_id:
2100            i = 0
2101            res = {}
2102            res[i] = t1
2103            for rank in range(self.world_size):
2104                if rank != self.rank:
2105                    res[i + 1] = rpc.rpc_sync(
2106                        worker_name(rank), torch.add, args=(res[i], t2)
2107                    )
2108                    i += 1
2109
2110            # Call custom function in middle of backward pass to ensure all
2111            # nodes are still waiting on a backward().
2112            res[i + 1] = DistAutogradTest.TestDebugInfoFunc.apply(res[i])
2113            i += 1
2114
2115            for rank in range(self.world_size):
2116                if rank != self.rank:
2117                    res[i + 1] = rpc.rpc_sync(
2118                        worker_name(rank), torch.add, args=(res[i], t2)
2119                    )
2120                    i += 1
2121
2122            dist_autograd.backward(context_id, [res[i].sum()])
2123
2124            debug_info = dist_autograd._get_debug_info()
2125            num_autograd_context = int(debug_info["num_autograd_contexts"])
2126            # Need at least one context and not more than 4.
2127            self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4)
2128
2129        for rd in range(self.world_size - 1):
2130            rpc.rpc_sync(
2131                worker_name((self.rank + rd + 1) % self.world_size),
2132                _set_rpc_done,
2133                args=(context_id, rd + 1),
2134            )
2135
2136        dist.barrier()
2137
2138        # Validate information
2139        debug_info = dist_autograd._get_debug_info()
2140        assert debug_info is not None
2141        self.assertEqual(0, int(debug_info["num_current_backward_passes"]))
2142        # Only `num_current_backward_passes` and `num_autograd_contexts` are reported.
2143        self.assertTrue(len(debug_info) == 2)
2144
2145        self.assertTrue(_all_contexts_cleaned_up())
2146
2147        # All contexts should be cleaned up.
2148        debug_info = dist_autograd._get_debug_info()
2149        self.assertEqual(0, int(debug_info["num_autograd_contexts"]))
2150
2151    @staticmethod
2152    def _workload_thread():
2153        t1 = torch.rand((3, 3), requires_grad=True)
2154        t2 = torch.rand((3, 3), requires_grad=True)
2155        with dist_autograd.context() as context_id:
2156            t3 = rpc.rpc_sync("worker0", torch.add, args=(t1, t2))
2157            t4 = rpc.rpc_sync("worker0", torch.mul, args=(t2, t3))
2158            t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4))
2159            t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5))
2160
2161            dist_autograd.backward(context_id, [t6.sum()])
2162
2163    @dist_init
2164    def test_async_dist_autograd(self):
2165        """
2166        This test ensures async processing for distributed autograd works
2167        appropriately. This is achieved by spawning multiple threads and
2168        hammering a single node with a lot of backward() calls.
2169        """
2170
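        # Every rank other than 0 spawns 20 threads, each running an RPC-heavy
        # forward/backward against worker0 (see _workload_thread above), while
        # rank 0 simply services those requests until the final barrier.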
2171        initialize_pg(self.file_init_method, self.rank, self.world_size)
2172        if self.rank != 0:
2173            # All other ranks schedule work on rank 0.
2174            threads = []
2175            for i in range(20):
2176                t = threading.Thread(target=DistAutogradTest._workload_thread)
2177                t.start()
2178                threads.append(t)
2179
2180            for thread in threads:
2181                thread.join()
2182
2183        dist.barrier()
2184
2185    @dist_init
2186    def test_backward_accumulate_grads(self):
2187        t1 = torch.rand((3, 3), requires_grad=True)
2188        t2 = torch.rand((3, 3), requires_grad=True)
2189        with dist_autograd.context() as context_id:
2190            t3 = torch.matmul(t1, t2)
2191            # Run backward twice.
2192            torch.autograd.backward([t3.sum()], retain_graph=True)
2193            torch.autograd.backward([t3.sum()])
2194
2195            t3 = rpc.rpc_sync(
2196                worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
2197            )
2198            # Run backward twice.
2199            dist_autograd.backward(context_id, [t3.sum()], retain_graph=True)
2200            dist_autograd.backward(context_id, [t3.sum()])
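            # Gradients from the two distributed backward calls accumulate in the
            # context, mirroring how t1.grad/t2.grad accumulated over the two local
            # backward calls above.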
2201
2202            # Verify the gradients are the same for local and remote execution.
2203            grads = dist_autograd.get_gradients(context_id)
2204            self.assertEqual(2, len(grads))
2205            self.assertIn(t1, grads)
2206            self.assertIn(t2, grads)
2207            self.assertEqual(t1.grad, grads[t1])
2208            self.assertEqual(t2.grad, grads[t2])
2209
2210    @staticmethod
2211    def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
2212        return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
2213
2214    @dist_init
2215    def test_nested_backward_accumulate_grads(self):
2216        self._nested_backward_accumulate_grads(
2217            torch.rand(3, 3, requires_grad=True),
2218            torch.rand(3, 3, requires_grad=True),
2219            False
2220        )
2221
2222    @dist_init
2223    def test_multiple_backward(self):
2224        self._multiple_backward(
2225            torch.rand(3, 3, requires_grad=True),
2226            torch.rand(3, 3, requires_grad=True),
2227            False
2228        )
2229
2230    @dist_init(clean_shutdown=False)
2231    def test_multiple_backward_with_errors(self):
2232        initialize_pg(self.file_init_method, self.rank, self.world_size)
2233        t1 = torch.rand((3, 3), requires_grad=True)
2234        t2 = torch.rand((3, 3), requires_grad=True)
2235        with dist_autograd.context() as context_id:
2236            loss = rpc.rpc_sync(
2237                f'worker{self._next_rank()}',
2238                DistAutogradTest._python_udf_with_backward_error,
2239                args=(t1, t2)).sum()
2240
2241            try:
2242                # Run backward in a loop multiple times.
2243                for i in range(100):
2244                    if i < 50:
2245                        with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"):
2246                            dist_autograd.backward(context_id, [loss], retain_graph=True)
2247                    elif i > 50:
2248                        # Recovered from error.
2249                        dist_autograd.backward(context_id, [loss], retain_graph=True)
2250                    else:
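                        # At i == 50, synchronize all ranks and disable the
                        # simulated error so the remaining iterations succeed.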
2251                        dist.barrier()
2252                        SimulateBackwardError._simulate_error = False
2253                        dist.barrier()
2254            finally:
2255                # Sync before resetting flag.
2256                dist.barrier()
2257
2258                # Reset the flag.
2259                SimulateBackwardError._simulate_error = True
2260
2261    @dist_init
2262    def test_backward_verify_hooks(self):
2263        t1 = torch.ones((3, 3), requires_grad=True)
2264        # Double the gradient.
2265        t1.register_hook(lambda grad: grad * 2)
2266        t2 = torch.ones((3, 3), requires_grad=True)
2267        local_grads = None
2268        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
2269            with dist_autograd.context() as context_id:
2270                ret = self._exec_func(exec_mode, torch.matmul, t1, t2)
2271                loss = ret.sum()
2272                ret = self._verify_backwards(
2273                    exec_mode, [loss], context_id, local_grads, t1, t2
2274                )
2275                local_grads = ret if ret else local_grads
2276
2277    @dist_init
2278    def test_no_grad_copy(self):
2279        '''
2280        Similar to test in test_autograd.py.
2281        '''
2282        # create autograd function that saves grad pointer as class static
2283        class MyFunc(Function):
2284            static_grad_ptr = None
2285
2286            @staticmethod
2287            def forward(ctx, inp1, inp2):
2288                return inp1 + inp2
2289
2290            @staticmethod
2291            def backward(ctx, grad):
2292                MyFunc.static_grad_ptr = grad.data_ptr()
2293                return grad, grad
2294
2295        class MyFuncSingleGrad(Function):
2296            static_grad_ptr = None
2297
2298            @staticmethod
2299            def forward(ctx, inp):
2300                return inp
2301
2302            @staticmethod
2303            def backward(ctx, grad):
2304                MyFuncSingleGrad.static_grad_ptr = grad.data_ptr()
2305                return grad
2306
2307        class NonContGradFunc(Function):
2308            @staticmethod
2309            def forward(ctx, inp1):
2310                ctx.size = inp1.size()
2311                return torch.tensor([1.])
2312
2313            @staticmethod
2314            def backward(ctx, grad):
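                # expand() returns a non-contiguous view, so this gradient can't be
                # reused in place and must be copied during accumulation.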
2315                return torch.ones(1).expand(ctx.size)
2316
2317        a = torch.randn(5, 6, requires_grad=True)
2318        b = torch.randn(5, 6, requires_grad=True)
2319        # non-contiguous grad should be copied
2320        with dist_autograd.context() as context_id:
2321            dist_autograd.backward(context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))])
2322            grads = dist_autograd.get_gradients(context_id)
2323            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
2324            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
2325
2326        # test case that should trigger no copy for a
2327        with dist_autograd.context() as context_id:
2328            dist_autograd.backward(context_id, [MyFuncSingleGrad.apply(a)[1][0]])
2329            grads = dist_autograd.get_gradients(context_id)
2330            p_g = MyFuncSingleGrad.static_grad_ptr
2331            p_a = grads[a].data_ptr()
2332            # Verify there was no clone.
2333            self.assertTrue(p_a == p_g)
2334
2335        # Test case that should trigger copy for both of a,b. This is
2336        # different in the distributed autograd case since we hold
2337        # a reference to all grads in a vector until all accumulation is done.
2338        with dist_autograd.context() as context_id:
2339            dist_autograd.backward(context_id, [MyFunc.apply(a, b)[1][0]])
2340            grads = dist_autograd.get_gradients(context_id)
2341            p_g = MyFunc.static_grad_ptr
2342            p_a = grads[a].data_ptr()
2343            p_b = grads[b].data_ptr()
2344            # Check that a and b use different grad buffers.
2345            self.assertFalse(p_a == p_b)
2346            # both should be copied.
2347            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
2348            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
2349
2350    @dist_init
2351    def test_no_grad_copy_sparse(self):
2352        # create autograd function that saves grad pointer as class static
2353        class MyFunc(Function):
2354            static_grad_ptr = None
2355
2356            @staticmethod
2357            def forward(ctx, inp):
2358                return inp
2359
2360            @staticmethod
2361            def backward(ctx, grad):
2362                MyFunc.static_grad_ptr = grad._values().data_ptr()
2363                return grad
2364
2365        class NonContGradFunc(Function):
2366            static_grad_ptr = None
2367
2368            @staticmethod
2369            def forward(ctx, inp1, inp2):
2370                return inp1 + inp2
2371
2372            @staticmethod
2373            def backward(ctx, grad):
2374                # Create a sparse tensor with non-contiguous indices and values
2375                # and return as grad.
2376                v = torch.rand(1, 3)
2377                i = torch.ones(1, 1, dtype=torch.long)
2378                nv = v.expand(8, 3)
2379                ni = i.expand(1, 8)
2380                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
2381                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
2382                return ngrad, ngrad
2383
2384        a = torch.randn(10, 3, requires_grad=True)
2385        b = torch.randn(10, 3, requires_grad=True)
2386        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
2387        offsets = torch.tensor([0, 4])
2388        import torch.nn.functional as F
2389
2390        # test case that should trigger no copy for a.
2391        with dist_autograd.context() as context_id:
2392            emb_matrix = MyFunc.apply(a)
2393            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
2394            dist_autograd.backward(context_id, [loss], retain_graph=True)
2395            grads = dist_autograd.get_gradients(context_id)
2396            p_g = MyFunc.static_grad_ptr
2397            p_a = grads[a]._values().data_ptr()
2398            # Check that a's grad uses the same buffer (no copy was made).
2399            self.assertTrue(p_a == p_g)
2400
2401            # Run backwards multiple times.
2402            for i in range(10):
2403                dist_autograd.backward(context_id, [loss], retain_graph=True)
2404
2405        # non-contiguous indices and value, we should trigger a copy.
2406        with dist_autograd.context() as context_id:
2407            emb_matrix = NonContGradFunc.apply(a, b)
2408            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
2409            dist_autograd.backward(context_id, [loss], retain_graph=True)
2410            grads = dist_autograd.get_gradients(context_id)
2411            p_g = NonContGradFunc.static_grad_ptr
2412            p_a = grads[a]._values().data_ptr()
2413            p_b = grads[b]._values().data_ptr()
2414            # Check that a and b use different grad buffers.
2415            self.assertFalse(p_a == p_b)
2416            # Verify we cloned both grads.
2417            self.assertFalse(p_a == p_g)
2418            self.assertFalse(p_b == p_g)
2419
2420            # Run backwards multiple times to verify accumulation.
2421            for i in range(10):
2422                dist_autograd.backward(context_id, [loss], retain_graph=True)
2423
2424    @dist_init
2425    def test_grad_copy_sparse_indices_extra_ref(self):
2426        # create autograd function that saves grad pointer as class static
2427        class MyFunc(Function):
2428            static_grad_ptr = None
2429            static_grad_indices_ref = None
2430            static_grad_values_ref = None
2431
2432            @staticmethod
2433            def forward(ctx, inp):
2434                return inp
2435
2436            @staticmethod
2437            def backward(ctx, grad):
2438                MyFunc.static_grad_ptr = grad._values().data_ptr()
2439                # indices() and values() return views, so holding onto
2440                # references of them would not increment refcount of indices
2441                # and values inside the sparse tensor.
2442                MyFunc.static_grad_indices_ref = grad._indices()
2443                MyFunc.static_grad_values_ref = grad._values()
2444                return grad
2445
2446        a = torch.randn(10, 3, requires_grad=True)
2447        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
2448        offsets = torch.tensor([0, 4])
2449        import torch.nn.functional as F
2450
2451        with dist_autograd.context() as context_id:
2452            emb_matrix = MyFunc.apply(a)
2453            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
2454            dist_autograd.backward(context_id, [loss], retain_graph=True)
2455            grads = dist_autograd.get_gradients(context_id)
2456            p_g = MyFunc.static_grad_ptr
2457            p_a = grads[a]._values().data_ptr()
2458            self.assertIsNotNone(MyFunc.static_grad_indices_ref)
2459            self.assertIsNotNone(MyFunc.static_grad_values_ref)
2460            # grad would be stolen, since static_grad_indices_ref and
2461            # static_grad_values_ref are holding onto views and don't bump the
2462            # refcount.
2463            self.assertTrue(p_g == p_a)
2464
2465    @dist_init
2466    def test_post_hooks(self):
2467        self.hook_called_times = 0
2468
2469        def post_hook_add_one(output_grads, input_grads):
2470            self.hook_called_times += 1
2471            return output_grads
2472
2473        def post_hook_add_two(output_grads, input_grads):
2474            self.hook_called_times += 2
2475            return output_grads
2476
2477        t = torch.rand(10, 10, requires_grad=True)
2478        a = t + t
2479
2480        # Register post hooks
2481        accumulate_grad_0 = a.grad_fn.next_functions[0][0]
2482        accumulate_grad_0.register_hook(post_hook_add_one)
2483        accumulate_grad_0.register_hook(post_hook_add_two)
2484
2485        accumulate_grad_1 = a.grad_fn.next_functions[1][0]
2486        accumulate_grad_1.register_hook(post_hook_add_two)
2487
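        # Together the registered hooks should add 1 + 2 + 2 = 5 to
        # hook_called_times, which is verified after backward below.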
2488        with dist_autograd.context() as context_id:
2489            loss = a.sum()
2490            dist_autograd.backward(context_id, [loss])
2491            self.assertEqual(5, self.hook_called_times)
2492            grads = dist_autograd.get_gradients(context_id)
2493            self.assertEqual(1, len(grads))
2494            self.assertTrue(t in grads)
2495
2496    @staticmethod
2497    def _slow_add(t1, t2):
2498        time.sleep(1)
2499        t3 = t1 + t2
2500        t3.requires_grad = True
2501        return t3
2502
2503    @dist_init
2504    def test_thread_local_context_id(self):
2505        t1 = torch.rand((3, 3))
2506        t2 = torch.rand((3, 3))
2507
2508        t3 = t1 + t2
2509        t3.requires_grad = True
2510        t3.sum().backward()
2511
2512        dst = worker_name((self.rank + 1) % self.world_size)
2513        rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2))
2514
2515        with dist_autograd.context() as context_id:
2516            loss = rref.to_here().sum()
2517            # due to slow add, the continuation of this backward pass will be
2518            # invoked by the previous rpc.remote thread which does not have a
2519            # valid context_id. So, this can test whether we propagate
2520            # thread_local states properly when jumping across threads on the
2521            # server side.
2522            dist_autograd.backward(context_id, [loss])
2523            self.assertTrue(
2524                rpc.rpc_sync(
2525                    dst,
2526                    _compare_owner_value,
2527                    args=(context_id, rref, t3.grad)
2528                )
2529            )
2530
2531
2532class CudaDistAutogradTest(CommonDistAutogradTest):
2533    @skip_if_lt_x_gpu(1)
2534    @dist_init
2535    def test_gpu_simple(self):
2536        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
2537        t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
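        # Populate t1.grad/t2.grad with a plain local backward first so they can be
        # compared against the dist autograd gradients below.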
2538        (t1 + t2).sum().backward()
2539        with dist_autograd.context() as context_id:
2540            t3 = t1 + t2
2541            dist_autograd.backward(context_id, [t3.sum()])
2542            grads = dist_autograd.get_gradients(context_id)
2543            self.assertEqual(2, len(grads))
2544            self.assertEqual(t1.grad, grads[t1])
2545            self.assertEqual(t2.grad, grads[t2])
2546
2547    @skip_if_lt_x_gpu(1)
2548    @dist_init
2549    def test_gpu_to_cpu_continuation(self):
2550        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
2551        t2 = torch.rand(3, 3, requires_grad=True)
2552        # Run a few iterations.
2553        for i in range(3):
2554            t1.grad = None
2555            t2.grad = None
2556            # Root is CPU
2557            local_grads = None
2558            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
2559                with dist_autograd.context() as context_id:
2560                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
2561                    t4 = t3.cuda(0) + t1
2562                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
2563                    t6 = t5.cuda(0) + t4
2564                    t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5)
2565                    # Autograd graph consists of CPU -> GPU -> CPU execution.
2566                    ret = self._verify_backwards(
2567                        exec_mode, [t7.sum()], context_id, local_grads, t1, t2
2568                    )
2569                    local_grads = ret if ret else local_grads
2570
2571    @skip_if_lt_x_gpu(1)
2572    @dist_init
2573    def test_gpu_to_cpu_continuation_gpu_root(self):
2574        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
2575        t2 = torch.rand(3, 3, requires_grad=True)
2576        # Run a few iterations.
2577        for i in range(3):
2578            t1.grad = None
2579            t2.grad = None
2580            # Root is on GPU in this variant.
2581            local_grads = None
2582            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
2583                with dist_autograd.context() as context_id:
2584                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
2585                    t4 = t3.cuda(0) + t1
2586                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
2587                    t6 = t5.cuda(0) + t4
2588                    # Autograd graph consists of CPU -> GPU -> CPU -> GPU execution.
2589                    ret = self._verify_backwards(
2590                        exec_mode, [t6.sum()], context_id, local_grads, t1, t2
2591                    )
2592                    local_grads = ret if ret else local_grads
2593
2594
2595class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
2596    # Reusing a simplified helper function from DistAutogradTest to ensure
2597    # autograd context is successfully cleaned up even when RPCs are failing.
2598    def context_cleanup_test_helper(self, rpc_args, func):
2599        initialize_pg(self.file_init_method, self.rank, self.world_size)
2600
2601        # Test that in dist autograd, in the case that tensors communicated over RPC do
2602        # NOT require grad, we still clean up the dist autograd contexts created
2603        # on other nodes. This is because the autograd context is still
2604        # communicated over RPC even if tensor arguments do not require grad, as
2605        # it is possible that the response could require grad.
2606        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
2607
2608        with dist_autograd.context() as context_id:
2609            for dst_rank in dst_ranks:
2610                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
2611                rpc.rpc_sync(
2612                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
2613                )
2614        # the thread's context id should be cleaned up
2615        with self.assertRaises(RuntimeError):
2616            dist_autograd._retrieve_context(context_id)
2617        # Ensure all peers have finished mutating the
2618        # `known_context_ids` set.
2619        dist.barrier()
2620        # check that all contexts have been cleaned up.
2621        success = _all_contexts_cleaned_up()
2622        self.assertTrue(success)
2623
2624    # no faulty_messages defined so this fails all retryable messages - see
2625    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
2626    @dist_init
2627    def test_context_cleanup_tensor_with_grad(self):
2628        t1 = torch.ones(3, 3, requires_grad=True)
2629        t2 = torch.zeros(3, 3, requires_grad=True)
2630        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
2631
2632    @dist_init
2633    def test_verify_backend_options(self):
2634        self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE)
2635        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
2636        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
2637        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
2638
2639
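# Thin wrapper used by test_gradients_synchronizations below: hosts `model` on the
# given device and exposes that model's parameter gradients from a dist autograd
# context via gradients().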
2640class WrapperModule(nn.Module):
2641    def __init__(self, model, device):
2642        super().__init__()
2643        self.model = model.to(device)
2644
2645    def forward(self, *args):
2646        return self.model(*args)
2647
2648    def gradients(self, ctx_id):
2649        grads = dist_autograd.get_gradients(ctx_id)
2650        return [grads[p] for p in self.model.parameters()]
2651
2652
2653class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
2654
2655    @skip_if_lt_x_gpu(4)
2656    def test_device_maps_backward_pass(self):
2657        options = self.rpc_backend_options
2658        dst = worker_name((self.rank + 1) % self.world_size)
2659
2660        # The reverse of this device mapping should be used for the backward pass.
2661        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
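        # For example, rank 0 maps its cuda:0 to cuda:1 on the next worker, so the
        # returned gradients are mapped back cuda:1 -> cuda:0 and land on the same
        # device as the original tensors (checked below).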
2662
2663        rpc.init_rpc(
2664            name=worker_name(self.rank),
2665            backend=self.rpc_backend,
2666            rank=self.rank,
2667            world_size=self.world_size,
2668            rpc_backend_options=options,
2669        )
2670
2671        t1 = torch.rand(10, device=self.rank, requires_grad=True)
2672        t2 = torch.rand(10, device=self.rank, requires_grad=True)
2673        with dist_autograd.context() as context_id:
2674            res = rpc.rpc_sync(dst, torch.add, args=(t1, t2))
2675            dist_autograd.backward(context_id, [res.sum()])
2676            grads = dist_autograd.get_gradients(context_id)
2677            self.assertEqual(torch.ones(10), grads[t1])
2678            self.assertEqual(torch.ones(10), grads[t2])
2679            self.assertEqual(t1.device, grads[t1].device)
2680            self.assertEqual(t2.device, grads[t2].device)
2681
2682        rpc.shutdown()
2683
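    # Two-stage pipeline used by test_dist_autograd_sync_streams below: the
    # local stage forwards its input to the remote stage (held through an RRef)
    # via RRef.rpc_sync(), so the remote multiply-by-2 is recorded in the
    # distributed autograd graph.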
2684    class MyRemoteCompute(torch.nn.Module):
2685        def forward(self, input):
2686            input = input * 2.0
2687            return input
2688
2689    class MyLocalCompute(torch.nn.Module):
2690        def __init__(self, next_stage):
2691            super().__init__()
2692            self.next_stage = next_stage
2693
2694        def forward(self, input):
2695            return self.next_stage.rpc_sync().forward(input)
2696
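    # Runs the same computation through local autograd and through distributed
    # autograd with a cross-device map, and checks that the gradients match;
    # as the name suggests, this exercises CUDA stream synchronization between
    # the device-map copies and the backward computation.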
2697    @skip_if_lt_x_gpu(4)
2698    def test_dist_autograd_sync_streams(self):
2699
2700        options = self.rpc_backend_options
2701        dst = worker_name((self.rank + 1) % self.world_size)
2702
2703        # The reverse of this device mapping should be used for the backward pass.
2704        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
2705
2706        rpc.init_rpc(
2707            name=worker_name(self.rank),
2708            backend=self.rpc_backend,
2709            rank=self.rank,
2710            world_size=self.world_size,
2711            rpc_backend_options=options,
2712        )
2713
2714        remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute)
2715        local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute)
2716        for _ in range(10):
2717            input = torch.rand([1000, 10000], device=self.rank, requires_grad=True)
2718            # Run local autograd
2719            result = input * 2.0
2720            r = random.random()
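            # The same random scale is reused for the distributed loss below so
            # that the gradients of the two runs are directly comparable.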
2721            loss = result.sum() * r
2722            loss.backward()
2723
2724            # Run distributed autograd
2725            with dist_autograd.context() as context_id:
2726                result = local_compute(input)
2727                loss = result.sum() * r
2728                dist_autograd.backward(context_id, [loss])
2729
2730                # Compare grads.
2731                grads = dist_autograd.get_gradients(context_id)
2732                self.assertEqual(input.grad, grads[input])
2733
2734        rpc.shutdown()
2735
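    # Rank 0 acts as the master: it keeps a local copy of every layer on cuda:0
    # and places one WrapperModule replica per remaining worker on that
    # worker's GPU, then checks that the distributed backward pass produces the
    # same gradients as the purely local backward pass.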
2736    @skip_if_lt_x_gpu(4)
2737    def test_gradients_synchronizations(self):
2738        options = self.rpc_backend_options
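        # Map this worker's GPU (cuda:{rank}) to cuda:{peer_rank} on every
        # peer, so tensors sent to a peer land on that peer's GPU and the
        # corresponding gradients come back to the local GPU.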
2739        for peer_rank in range(self.world_size):
2740            options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank})
2741
2742        rpc.init_rpc(
2743            name=worker_name(self.rank),
2744            backend=self.rpc_backend,
2745            rank=self.rank,
2746            world_size=self.world_size,
2747            rpc_backend_options=options,
2748        )
2749
2750        if self.rank == 0:
2751            # This is the master; it drives both the local and remote iterations.
2752            layers = [nn.Linear(2000, 2000) for _ in range(self.world_size - 1)]
2753            local_layers = [layer.to(0) for layer in layers]
2754            remote_layers = []
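            # Build one WrapperModule per remaining worker; each wraps the same
            # weights as the corresponding local layer but lives on that
            # worker's GPU, so the local and remote models start from identical
            # parameters.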
2755            for rank in range(1, self.world_size):
2756                remote_layers.append(rpc.remote(
2757                    worker_name(rank),
2758                    WrapperModule,
2759                    args=(layers[rank - 1], rank)
2760                ))
2761
2762            x = torch.randn(5000, 2000).to(0)
2763            # local iteration
2764            local_model = nn.Sequential(*local_layers)
2765            local_model(x).sum().backward()
2766
2767            # remote iteration
2768            with dist_autograd.context() as context_id:
2769                for remote_layer in remote_layers:
2770                    x = remote_layer.rpc_sync().forward(x)
2771
2772                dist_autograd.backward(context_id, [x.sum()])
2773
2774                futs = []
2775                for remote_layer in remote_layers:
2776                    futs.append(remote_layer.rpc_async().gradients(context_id))
2777
2778                for i, fut in enumerate(futs):
2779                    local_gradients = [p.grad for p in local_layers[i].parameters()]
2780                    for g1, g2 in zip(fut.wait(), local_gradients):
2781                        self.assertEqual(g1, g2)
2782
2783        rpc.shutdown()
2784