# mypy: allow-untyped-defs

import sys
import threading
import time
from enum import Enum
import random
import torch
import torch.nn as nn
from datetime import timedelta
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributed.rpc import RRef
from torch.testing._internal.common_utils import IS_MACOS, skip_but_pass_in_sandcastle_if
from torch.testing._internal.dist_utils import (
    dist_init,
    initialize_pg,
    wait_until_node_failure,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu


# Right now we test up to 3-layer nested rpc calls.
# rpc_done[1] and ctx_ids[1] represent that the rpc is done in the prev rank and
# the context id sent from the prev rank, respectively.
# rpc_done[2] and ctx_ids[2] represent the prev of the prev rank.
# rpc_done[3] and ctx_ids[3] represent the prev of the prev of the prev rank.
# rpc_done[0] and ctx_ids[0] represent the current rank, but are mostly unused.
rpc_done = [False, False, False, False]
ctx_ids = [-1, -1, -1, -1]

known_context_ids = set()

requires_grad_tensor = torch.ones(3, 3, requires_grad=True)

# Send rpc done info and context_id to
# dst_rank = (self.rank + rank_distance) % self.world_size
# we don't need a lock here since the GIL is held while executing remote
# python UDFs, so access is serialized across several workers.
def _set_rpc_done(ctx_id, rank_distance):
    global rpc_done
    global ctx_ids
    global known_context_ids
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id
    known_context_ids.add(ctx_id)


def _check_rpc_done(rank_distance):
    while not rpc_done[rank_distance]:
        time.sleep(0.1)


def _torch_ones(sizes, requires_grad=False):
    return torch.ones(sizes, requires_grad=requires_grad)
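
# Illustrative sketch (not used by the tests): how a rank typically combines the
# helpers above. The name `_example_rpc_done_handshake` is hypothetical; it only
# demonstrates the notify/wait handshake for a peer at rank distance 1.
def _example_rpc_done_handshake(dst, context_id):
    # Tell `dst` that this rank has finished its RPC work for `context_id`;
    # on the callee this sets rpc_done[1] and records ctx_ids[1].
    rpc.rpc_sync(dst, _set_rpc_done, args=(context_id, 1))
    # Block until some peer has marked this rank's own rpc_done[1] slot.
    _check_rpc_done(1)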

# This method must be called on the rref owner and verifies that the grad of the
# rref tensor equals the given grad.
def _compare_owner_value(context_id, rref, grad):
    grads = dist_autograd.get_gradients(context_id)
    x = grads[rref.local_value()]
    if x.is_sparse:
        assert grad.is_sparse
        x = x.to_dense()
        grad = grad.to_dense()
    else:
        assert not grad.is_sparse
    return torch.equal(x, grad)


def create_tensor():
    return torch.ones((3, 3), requires_grad=True)


def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
    i = [[0, 1, 1], [2, 0, 2]]
    v = [3.2, 4.1, 5.3]
    tensor = torch.sparse_coo_tensor(
        i, v, (3, 3), requires_grad=requires_grad, dtype=dtype
    )
    if coalesce:
        tensor = tensor.coalesce()
    return tensor


@torch.jit.script
def create_torchscript_tensor() -> torch.Tensor:
    return torch.ones((3, 3)).requires_grad_()


def my_py_add(t1, t2):
    return torch.add(t1, t2)


def my_scalar_add(a, b):
    return a + b


def my_rref_add(rref_t1, t2):
    ret = torch.add(rref_t1.local_value(), t2)
    return ret


@torch.jit.script
def my_script_add(t1, t2):
    return torch.add(t1, t2)


@torch.jit.script
def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor:
    t1 = ref_t1.to_here()
    return torch.add(t1, t2)


def my_nested_rref_add(dst, rref_t1, t2):
    return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))


def ret_requires_grad():
    return requires_grad_tensor


def my_py_nested_call(t1, t2, dst, world_size, hops):
    next_dst = (dst + 1) % world_size
    if hops > 0:
        return rpc.rpc_sync(
            worker_name(next_dst),
            my_py_nested_call,
            args=(t1, t2, next_dst, world_size, hops - 1),
        )
    else:
        return rpc.rpc_sync(worker_name(next_dst), my_py_add, args=(t1, t2))


# After a dist autograd context is cleaned up locally, it should also be cleaned
# up on other nodes. This helper allows timeout_seconds for those RPCs to
# complete, and checks that all known contexts have been cleaned up in that
# timeframe.
def _all_contexts_cleaned_up(timeout_seconds=10):
    global known_context_ids
    start = time.time()
    context_id_to_raised = set()
    while (
        time.time() - start < timeout_seconds
        and context_id_to_raised != known_context_ids
    ):
        for context_id in known_context_ids:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                context_id_to_raised.add(context_id)
    # all contexts have been cleaned up if trying to retrieve any context resulted in a RuntimeError.
    success = context_id_to_raised == known_context_ids
    return success
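
# Worked example (illustrative only, not called by any test): the tensor that
# build_sparse_tensor() returns has non-zero entries at (0, 2), (1, 0) and (1, 2)
# with values 3.2, 4.1 and 5.3, so its dense form is
# [[0, 0, 3.2], [4.1, 0, 5.3], [0, 0, 0]] and torch.sparse.sum() reduces it to a
# 0-dim tensor of roughly 12.6, the kind of scalar loss the sparse tests pass to
# dist_autograd.backward.
def _example_sparse_loss():
    t = build_sparse_tensor()
    assert torch.allclose(t.to_dense().sum(), torch.tensor(12.6))
    return torch.sparse.sum(t)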

# This function creates a dist autograd context, runs rpc_sync on the given ps,
# and then blocks until the ps has verified that the grads are correctly accumulated.
def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
    with dist_autograd.context() as context_id:
        ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
        if sparse:
            loss = torch.sparse.sum(ret)
        else:
            loss = ret.sum()
        dist_autograd.backward(context_id, [loss])
        # prevent deleting dist autograd context
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))

# This function is the same as _run_trainer, except that the rpc call invokes the
# torchscript function "my_script_ref_add" instead of the python function "my_rref_add".
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
    with dist_autograd.context() as context_id:
        ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
        if sparse:
            loss = torch.sparse.sum(ret)
        else:
            loss = ret.sum()
        dist_autograd.backward(context_id, [loss])
        # prevent deleting dist autograd context
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))


class SimulateBackwardError(Function):
    _simulate_error = True

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, input):
        if SimulateBackwardError._simulate_error:
            raise Exception("Simulate error on backward pass")  # noqa: TRY002
        else:
            return input


class ExecMode(Enum):
    LOCAL = 1  # Run the operation locally.
    RPC_SYNC = 2  # Run the operation using rpc_sync
    REMOTE = 3  # Run the operation using remote.
    RPC_ASYNC = 4  # Run the operation using rpc_async


# Common utils for both CPU and CUDA test suites
class CommonDistAutogradTest(RpcAgentTestFixture):
    def _exec_func_with_dst(self, dst, exec_mode, method, *args):
        if ExecMode.LOCAL == exec_mode:
            if len(args) == 1 and isinstance(args[0], list):
                return method(*args[0])
            return method(*args)
        elif ExecMode.RPC_SYNC == exec_mode:
            return rpc.rpc_sync(worker_name(dst), method, args=(args))
        elif ExecMode.REMOTE == exec_mode:
            return rpc.remote(worker_name(dst), method, args=(args)).to_here()
        elif ExecMode.RPC_ASYNC == exec_mode:
            fut = rpc.rpc_async(worker_name(dst), method, args=(args))
            return fut.wait()
        else:
            raise ValueError(f"Unrecognized ExecMode {exec_mode}")

    def _exec_func(self, exec_mode, method, *args):
        return self._exec_func_with_dst(
            self._next_rank(), exec_mode, method, *args
        )

    def _next_rank(self):
        if hasattr(self, "dst_rank"):
            self.dst_rank = (self.dst_rank + 1) % self.world_size
            if self.dst_rank == self.rank:
                return self._next_rank()
        else:
            self.dst_rank = (self.rank + 1) % self.world_size
        return self.dst_rank

    def _check_rpc_done(self, rank_distance):
        _check_rpc_done(rank_distance)

    def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
        if exec_mode == ExecMode.LOCAL:
            torch.autograd.backward(tensors)
            return [arg.grad for arg in args]
        else:
            self._verify_backwards_remote(tensors, context_id, local_grads, *args)

    def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
        dist_autograd.backward(context_id, tensors)

        # Verify grads were accumulated appropriately.
        grads = dist_autograd.get_gradients(context_id)
        nargs = len(args)
        ngrads = 0
        for i in range(0, nargs):
            if local_grads[i] is not None:
                self.assertIn(args[i], grads)
                self.assertEqual(local_grads[i], grads[args[i]])
                ngrads += 1
            else:
                self.assertNotIn(args[i], grads)

        self.assertEqual(ngrads, len(grads))

    def _test_graph(self, fn, exec_mode, sparse):
        dst_rank = (self.rank + 1) % self.world_size

        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor()
                t2 = build_sparse_tensor()
            else:
                t1 = torch.ones(3, 3, requires_grad=True)
                t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank), fn, args=(t1, t2)
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            rpc.rpc_sync(
                worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
            )

            # Verify graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                next(iter(send_functions.values())),
                next(iter(recv_functions.values())),
                t1,
                t2,
                ret,
            )

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # Verify graph for previous context id.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()

    # 3-layer nested calls
    def _test_graph_for_py_nested_call(self, exec_mode, sparse):
        dst_rank = (self.rank + 1) % self.world_size

        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor(requires_grad=True)
                t2 = build_sparse_tensor(requires_grad=True)
            else:
                t1 = torch.ones(3, 3, requires_grad=True)
                t2 = torch.zeros(3, 3, requires_grad=True)
            nest_dst_rank = (dst_rank + 1) % self.world_size
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(t1, t2, dst_rank, self.world_size, 1),
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(t1, t2, dst_rank, self.world_size, 1),
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            # Barrier to ensure all RPCs are done.
            dist.barrier()

            for rd in [1, 2, 3]:
                rpc.rpc_sync(
                    worker_name((self.rank + rd) % self.world_size),
                    _set_rpc_done,
                    args=(context_id, rd),
                )

            # Barrier to ensure all set_rpc_done have completed.
            dist.barrier()

            # For self.rank, there are 4 graphs to verify:
            # The first is for the current context id when this rank sends the
            # first rpc call.
            # The second is for the prev context id when this rank makes the
            # 1st nested call.
            # The third is for the prev prev context id when this rank makes
            # the 2nd nested call.
            # The last is for the prev prev prev context id when this rank
            # executes the torch.add() operator.

            # Verify first graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                next(iter(send_functions.values())),
                next(iter(recv_functions.values())),
                t1,
                t2,
                ret,
            )

            # Verify second graph for 1st nested call.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify third graph for 2nd nested call.
            ctx = dist_autograd._retrieve_context(ctx_ids[2])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify last graph for rpc call execution.
            ctx = dist_autograd._retrieve_context(ctx_ids[3])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

    # Rank0->Rank1->Rank0
    def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse):
        dst_rank = (self.rank + 1) % self.world_size

        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor(requires_grad=True)
                t2 = build_sparse_tensor(requires_grad=True)
            else:
                t1 = torch.ones(3, 3, requires_grad=True)
                t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(
                        t1,
                        t2,
                        (self.rank - 1 + self.world_size) % self.world_size,
                        self.world_size,
                        0,
                    ),
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank),
                    my_py_nested_call,
                    args=(
                        t1,
                        t2,
                        (self.rank - 1 + self.world_size) % self.world_size,
                        self.world_size,
                        0,
                    ),
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            rpc.rpc_sync(
                worker_name((self.rank + 1) % self.world_size),
                _set_rpc_done,
                args=(context_id, 1),
            )

            # For self.rank, there are 2 graphs to verify.
            # One is for the current context id when this rank sends the first
            # rpc call and executes the torch.add() operator.
            # The other is for the prev context id when this rank makes the
            # nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                next(iter(send_functions.values())),
                list(recv_functions.values())[1],
                t1,
                t2,
                ret,
            )
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

            # Verify two pairs of send and recv functions for the nested
            # call.
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            if sparse:
                t1 = build_sparse_tensor(requires_grad=False)
                t2 = build_sparse_tensor(requires_grad=False)
            else:
                t1 = torch.ones(3, 3, requires_grad=False)
                t2 = torch.zeros(3, 3, requires_grad=False)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank), torch.add, args=(t1, t2)
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank), torch.add, args=(t1, t2)
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            rpc.rpc_sync(
                worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
            )

            ctx = dist_autograd._current_context()
            send_functions = ctx._send_functions()
            self.assertEqual(len(send_functions), 0)
            recv_functions = ctx._recv_functions()
            self.assertEqual(len(recv_functions), 0)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # NB: RRef.to_here() always passes the autograd context to the
            # callee, as the caller does not know whether the return
            # value would contain a requires_grad tensor or not.
            #
            # rpc/remote with a udf (_set_rpc_done here) also always passes the
            # autograd context to the callee for the same reason.
            self.assertNotEqual(-1, dist_autograd._retrieve_context(ctx_ids[1]))
            dist.barrier()

    def _test_rpc_complex_args(self, exec_mode, sparse):
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                if sparse:
                    tensor = build_sparse_tensor(requires_grad=(i % 2 == 0))
                else:
                    tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0))
                tensors.append(tensor)
            dst_rank = self._next_rank()
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank), torch.stack, args=(tensors,)
                )
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank), torch.stack, args=(tensors,)
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            self.assertEqual(torch.stack(tensors), ret)

            # Verify appropriate tensors have been attached to the autograd graph.
            next_funcs = next(iter(dist_autograd._current_context()._send_functions().values())).next_functions
            idx = 0
            for i in range(len(next_funcs)):
                self.assertEqual(
                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                )
                self.assertEqual(tensors[i], next_funcs[i][0].variable)

            # Verify that the worker id has been recorded in the context
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 1)
            self.assertEqual(worker_ids, {dst_rank})

    def context_cleanup_test_helper(self, rpc_args, func, nested=False):
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        # test that in dist autograd, in the case that tensors communicated over RPC do
        # NOT require grad, we still clean up the dist autograd contexts created
        # on other nodes. This is because the autograd context is still
        # communicated over RPC even if tensor arguments do not require grad, as
        # it is possible that the response could.
        if nested:
            dst_rank = (self.rank + 1) % self.world_size
            nested_dst_rank = (dst_rank + 1) % self.world_size
            dst_ranks = {dst_rank}
        else:
            dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}

        with dist_autograd.context() as context_id:
            for dst_rank in dst_ranks:
                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
                rpc.rpc_sync(
                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
                if nested:
                    rpc.rpc_sync(
                        worker_name(nested_dst_rank),
                        _set_rpc_done,
                        args=(context_id, 2),
                    )
        # the thread's context id should be cleaned up
        with self.assertRaises(RuntimeError):
            dist_autograd._retrieve_context(context_id)
        # Ensure all peers have finished mutating the
        # `known_context_ids` set.
        dist.barrier()
        # check that all contexts have been cleaned up.
        success = _all_contexts_cleaned_up()
        self.assertTrue(success)

    def _backward_no_grad_on_tensor(self, t1, t2, sparse):
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                worker_name(self._next_rank()),
                torch.add,
                args=(t1, t2))
            if sparse:
                loss = torch.sparse.sum(loss)
            else:
                loss = loss.sum()
            dist_autograd.backward(context_id, [loss], retain_graph=True)
            self.assertIsNone(t1.grad)
            self.assertIsNone(t2.grad)

            # Now populate .grad with local autograd engine and
            # verify dist autograd doesn't mess with it.
            loss_local = torch.add(t1, t2)
            if sparse:
                loss_local = torch.sparse.sum(loss_local)
            else:
                loss_local = loss_local.sum()
            loss_local.backward()
            self.assertIsNotNone(t1.grad)
            self.assertIsNotNone(t2.grad)

            t1_grad_before = t1.grad
            t2_grad_before = t2.grad
            dist_autograd.backward(context_id, [loss])
            self.assertEqual(t1_grad_before, t1.grad)
            self.assertEqual(t2_grad_before, t2.grad)

    # The current rank first creates a tensor on the rref_owner, and then passes
    # the rref with another tensor to the callee to run either my_rref_add or
    # my_nested_rref_add, depending on whether the callee is the rref owner.
    # The grad of the tensor lives on the current rank, and the grad of the rref
    # tensor lives on the rref owner.
    def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse):
        local_ret = torch.add(t1, t2)
        if sparse:
            local_ret = torch.sparse.sum(local_ret)
        else:
            local_ret = local_ret.sum()
        local_ret.backward()
        with dist_autograd.context() as context_id:
            if sparse:
                rref_t1 = rpc.remote(
                    rref_owner, build_sparse_tensor, args=(False, True,)
                )
            else:
                rref_t1 = rpc.remote(
                    rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
                )
            if callee == rref_owner:
                rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
            else:
                rref = rpc.remote(
                    callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
                )
            ret = rref.to_here()
            if sparse:
                ret = torch.sparse.sum(ret)
            else:
                ret = ret.sum()
            dist_autograd.backward(context_id, [ret])

            # verify grads on caller
            grads = dist_autograd.get_gradients(context_id)
            self.assertIn(t2, grads)
            self.assertEqual(grads[t2], t2.grad)

            # verify grads on rref owner
            self.assertTrue(
                rpc.rpc_sync(
                    rref_owner,
                    _compare_owner_value,
                    args=(context_id, rref_t1, t1.grad),
                )
            )

    # In this test, every rank serves as a parameter server (ps) and a driver,
    # and kicks off trainers on the other three ranks. So, with world_size 4 we have:
    # ps = rank0 with trainers = rank1/2/3
    # ps = rank1 with trainers = rank2/3/0
    # ps = rank2 with trainers = rank3/0/1
    # ps = rank3 with trainers = rank0/1/2
    #
    # These four ps-trainer groups run on completely separate autograd
    # graphs, but they share the same set of underlying RpcAgents.
    def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse):
        if sparse:
            t1 = build_sparse_tensor(requires_grad=True)
            t2 = build_sparse_tensor(requires_grad=True)
        else:
            t1 = torch.ones((3, 3), requires_grad=True)
            t2 = torch.zeros((3, 3), requires_grad=True)

        local_ret = torch.add(t1, t2)
        if sparse:
            torch.sparse.sum(local_ret).backward()
        else:
            local_ret.sum().backward()

        # create rref on self
        rref_t1 = rpc.remote(
            worker_name(self.rank),
            create_ref_fn,
            args=())

        # kick off forward and backward pass on three other workers (trainers)
        rank_diffs = [1, 2, 3]
        futures = []
        for rank_diff in rank_diffs:
            futures.append(
                rpc.rpc_async(
                    worker_name((self.rank + rank_diff) % self.world_size),
                    trainer_fn,
                    args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
                )
            )

        # check whether the trainers are done with their backward pass
        for rank_diff in rank_diffs:
            self._check_rpc_done(rank_diff)

        # trainers are done and holding the context for verification
        accumulate_grad_func = None
        for rank_diff in rank_diffs:
            # make sure grads are accumulated for the same tensors and values
            # are all correct
            ctx_id = ctx_ids[rank_diff]
            grads = dist_autograd.get_gradients(ctx_id)
            local_t1 = rref_t1.to_here()
            self.assertIn(local_t1, grads)
            self.assertEqual(grads[local_t1], t1.grad)

        # unblock trainers
        _set_rpc_done(None, 0)

        # wait until all trainers are done
        torch.futures.wait_all(futures)

    def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse):
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                # Multiple RPCs between different nodes.
                val = self._exec_func(exec_mode, torch.add, t1, t2)
                val = self._exec_func(exec_mode, torch.mul, t3, val)
                s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
                s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
                if sparse:
                    val = self._exec_func(exec_mode, torch.mul, s1, s2)
                    val = self._exec_func(exec_mode, torch.mul, val, val)
                    loss = torch.sparse.sum(val)
                else:
                    val = self._exec_func(exec_mode, torch.bmm, s1, s2)
                    val = self._exec_func(exec_mode, torch.matmul, val, val)
                    loss = val.sum()

                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
                )
                local_grads = ret if ret else local_grads

    def _backward_different_dtypes(self, t1, t2, sparse):
        local_grads = None
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                loss = self._exec_func(exec_mode, torch.add, t1, t2)
                if sparse:
                    loss = torch.sparse.sum(loss)
                else:
                    loss = loss.sum()
                local_grads = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )

    # Run the same code locally and with dist autograd and verify that the
    # gradients are the same.
    def _backward_simple_python_udf(self, t1, t2, sparse):
        local_grads = None
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(exec_mode, my_py_add, t1, t2)
                if sparse:
                    loss = torch.sparse.sum(ret)
                else:
                    loss = ret.sum()
                local_grads = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )

    # Run the same code locally and with dist autograd and verify that the
    # gradients are the same.
    def _backward_simple_script_call(self, t1, t2, sparse):
        local_grads = None
        for exec_mode in [
            ExecMode.LOCAL,
            ExecMode.RPC_SYNC,
            ExecMode.RPC_ASYNC,
            ExecMode.REMOTE,
        ]:
            with dist_autograd.context() as context_id:
                forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
                if sparse:
                    loss = torch.sparse.sum(forward_ret)
                else:
                    loss = forward_ret.sum()
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )
                local_grads = ret if ret else local_grads

    def _nested_backward_accumulate_grads(self, t1, t2, sparse):
        with dist_autograd.context() as context_id:
            ret = rpc.rpc_sync(
                worker_name(self._next_rank()),
                DistAutogradTest._test_nested_backward_accumulate_grads,
                args=(t1, t2, self._next_rank()),
            )
            if sparse:
                loss = torch.sparse.sum(ret)
            else:
                loss = ret.sum()
            # Run backward twice.
            dist_autograd.backward(context_id, [loss], retain_graph=True)
            dist_autograd.backward(context_id, [loss])

    def _backwards_nested_python_udf(self, t1, t2, sparse):
        t3 = t1 * t2
        t4 = t1 + t2
        res = t3 + t4
        loss = t1 * t2 * t3 * t4 * res
        if sparse:
            loss = torch.sparse.sum(loss)
        else:
            loss = loss.sum()
        torch.autograd.backward([loss])

        # Now run distributed autograd.
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                worker_name(self._next_rank()),
                DistAutogradTest._nested_python_udf,
                args=(t1, t2, self._next_rank()),
            )
            if sparse:
                loss = torch.sparse.sum(loss)
            else:
                loss = loss.sum()
            dist_autograd.backward(context_id, [loss])
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(t1.grad, grads[t1])
            self.assertEqual(t2.grad, grads[t2])

    def _mixed_requires_grad(self, t1, t2, sparse):
        for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(
                    exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2
                )
                self.assertEqual(t1 * t2, ret)
                if sparse:
                    loss = torch.sparse.sum(ret)
                else:
                    loss = ret.sum()
                dist_autograd.backward(context_id, [loss])
                self.assertTrue(t1.requires_grad)
                self.assertFalse(t2.requires_grad)
                grads = dist_autograd.get_gradients(context_id)
                self.assertIn(t1, grads)
                self.assertNotIn(t2, grads)
                self.assertEqual(t2, grads[t1])

    def _multiple_backward(self, t1, t2, sparse):
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                worker_name(self._next_rank()),
                torch.add,
                args=(t1, t2))
            if sparse:
                loss = torch.sparse.sum(loss)
            else:
                loss = loss.sum()
            # Run backward in a loop multiple times.
            for i in range(1000):
                dist_autograd.backward(context_id, [loss], retain_graph=True)

    # For the current context, this rank sends t1 and t2 tensors to dst_rank,
    # then gets the t3 = torch.add(t1, t2) result tensor.
    # For the current context in this rank, it expects a graph like this:
    #  send function:
    #              rpcSendBackward
    #                 /         \
    #  t1.AccumulateGrad     t2.AccumulateGrad
    #
    #  recv function:
    #
    #             |
    #     t3.rpcRecvBackward
    #
    def _verify_graph_for_first_rpc_call(
        self, send_function, recv_function, t1, t2, ret
    ):
        # Retrieve the next functions in the graph.
        next_funcs = send_function.next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

        # Test recv functions.
        self.assertEqual(ret.grad_fn, recv_function)

    # Run the same code locally and with dist autograd and verify that the
    # gradients are the same.
    def _backward_simple(self, dst, t1, t2, local_grads, sparse):
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func_with_dst(
                    dst, exec_mode, torch.add, t1, t2
                )
                if sparse:
                    loss = torch.sparse.sum(ret)
                else:
                    loss = ret.sum()
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )
                local_grads = ret if ret else local_grads

    # For a context passed from previous nested chain calls, this rank
    # receives two tensors t1 and t2, executes torch.add(t1, t2) and sends
    # result tensor t3 back.
    # For this context in this rank, it expects a graph like this:
    #  send and recv functions:
    #       rpcSendBackward
    #              |
    #       t3.AddBackward0
    #          /         \
    #  t1.recvRpcBackward   t2.recvRpcBackward
    def _verify_graph_for_rpc_call_exec(self, send_function):
        # Verify next function is AddBackward0
        next_funcs = send_function.next_functions
        self.assertEqual(1, len(next_funcs))
        add_backward_fn = next_funcs[0][0]
        self.assertEqual("AddBackward0", add_backward_fn.name())

        # Verify the next two functions are the same recv backward function.
        next_funcs = add_backward_fn.next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

    # For a context passed from previous nested chain calls, this rank
    # receives two tensors t1 and t2 and forwards them via a nested rpc call to
    # the next dst. On the return route, it receives result tensor t3 from the
    # next dst and forwards t3 back to the previous caller.
    # For this context in this rank, it expects a graph like this:
    #  send and recv functions for receiving and forwarding t1 and t2:
    #       rpcSendBackward
    #          /         \
    #  t1.recvRpcBackward   t2.recvRpcBackward
    #  send and recv functions for receiving and forwarding t3:
    #       rpcSendBackward
    #              |
    #       t3.recvRpcBackward
    def _verify_graph_for_nested_rpc_call(self, ctx):
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))

        # For the send function created when making the nested rpc call,
        # the next functions of the send function are the two recv functions
        # for the two tensors received from the previous call.
        next_funcs = next(iter(send_functions.values())).next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

        # For the send function created when returning the response to the previous
        # call, the next function of the send function is the recv function for the
        # result tensor returned from the nested call.
        next_funcs = list(send_functions.values())[1].next_functions
        self.assertEqual(1, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
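
# Illustrative sketch (not part of the test suite): how the CommonDistAutogradTest
# helpers above are meant to be combined. `_exec_func` runs the same op either
# locally or over RPC depending on ExecMode, and `_verify_backwards` compares the
# gradients dist autograd produced against the locally computed ones. The function
# name `_example_exec_mode_sweep` is hypothetical.
def _example_exec_mode_sweep(test, t1, t2):
    local_grads = None
    for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
        with dist_autograd.context() as context_id:
            # Same op everywhere; only the execution path changes per ExecMode.
            loss = test._exec_func(exec_mode, torch.add, t1, t2).sum()
            ret = test._verify_backwards(
                exec_mode, [loss], context_id, local_grads, t1, t2
            )
            # LOCAL mode returns the reference grads; remote modes return None.
            local_grads = ret if ret else local_grads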


class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):

    # Sparse tests only work with TensorPipeAgent.
    @dist_init
    def test_graph_for_builtin_call_sparse(self):
        self._test_graph(torch.add, ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_python_call_sparse(self):
        self._test_graph(my_py_add, ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_builtin_remote_call_sparse(self):
        self._test_graph(torch.add, ExecMode.REMOTE, True)

    @dist_init
    def test_graph_for_python_remote_call_sparse(self):
        self._test_graph(my_py_add, ExecMode.REMOTE, True)

    @dist_init
    def test_graph_for_py_nested_call_sparse(self):
        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_py_nested_remote_call_sparse(self):
        self._test_graph_for_py_nested_call(ExecMode.REMOTE, True)

    @dist_init
    def test_graph_for_py_nested_call_itself_sparse(self):
        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_graph_for_py_nested_remote_call_itself_sparse(self):
        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True)

    @dist_init
    def test_no_graph_with_tensors_not_require_grad_sparse(self):
        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_no_graph_with_tensors_not_require_grad_remote_sparse(self):
        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True)

    @dist_init
    def test_rpc_complex_args_sparse(self):
        self._test_rpc_complex_args(ExecMode.RPC_SYNC, True)

    @dist_init
    def test_remote_complex_args_sparse(self):
        self._test_rpc_complex_args(ExecMode.REMOTE, True)

    @dist_init
    def test_context_cleanup_tensor_with_grad_sparse(self):
        t1 = build_sparse_tensor(requires_grad=True)
        t2 = build_sparse_tensor(requires_grad=True)
        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)

    @dist_init
    def test_context_cleanup_tensor_no_grad_sparse(self):
        t1 = build_sparse_tensor(requires_grad=False)
        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)

    @dist_init
    def test_context_cleanup_nested_rpc_sparse(self):
        t1 = build_sparse_tensor(requires_grad=True)
        t2 = build_sparse_tensor(requires_grad=True)
        dst_rank = (self.rank + 1) % self.world_size
        args = (t1, t2, dst_rank, self.world_size, 0)
        self.context_cleanup_test_helper(
            rpc_args=args, func=my_py_nested_call, nested=True
        )

    @dist_init
    def test_backward_no_grad_on_tensor_sparse(self):
        self._backward_no_grad_on_tensor(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_backward_simple_sparse(self):
        self._backward_simple(
            self._next_rank(),
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_simple_self_sparse(self):
        self._backward_simple(
            self.rank,
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_rref_multi_sparse(self):
        if self.rank > 0:
            callee = "worker0"
            rref_owner = callee
            self._backward_rref(
                callee,
                rref_owner,
                build_sparse_tensor(requires_grad=True),
                build_sparse_tensor(requires_grad=True),
                None,
                True
            )

    @dist_init
    def test_backward_rref_sparse(self):
        callee = worker_name(self._next_rank())
        rref_owner = callee
        self._backward_rref(
            callee,
            rref_owner,
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_rref_nested_sparse(self):
        callee = worker_name((self.rank + 1) % self.world_size)
        rref_owner = worker_name((self.rank + 2) % self.world_size)
        self._backward_rref(
            callee,
            rref_owner,
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_trainer_ps_sparse(self):
        self._test_trainer_ps(
            build_sparse_tensor,
            _run_trainer,
            True
        )

    @dist_init
    def test_backward_multiple_round_trips_sparse(self):
        self._backward_multiple_round_trips(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=False),
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=False),
            build_sparse_tensor(requires_grad=True),
            None,
            True
        )

    @dist_init
    def test_backward_different_dtypes_sparse(self):
        self._backward_different_dtypes(
            build_sparse_tensor(requires_grad=True, dtype=torch.float32),
            build_sparse_tensor(requires_grad=True, dtype=torch.float64),
            True
        )

    @dist_init
    def test_backward_simple_python_udf_sparse(self):
        self._backward_simple_python_udf(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_backward_simple_script_call_sparse(self):
        self._backward_simple_script_call(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_nested_backward_accumulate_grads_sparse(self):
        self._nested_backward_accumulate_grads(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_backwards_nested_python_udf_sparse(self):
        # Run equivalent of _nested_python_udf locally.
        self._backwards_nested_python_udf(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_mixed_requires_grad_sparse(self):
        self._mixed_requires_grad(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=False),
            True
        )

    @dist_init
    def test_multiple_backward_sparse(self):
        self._multiple_backward(
            build_sparse_tensor(requires_grad=True),
            build_sparse_tensor(requires_grad=True),
            True
        )

    @dist_init
    def test_embedding_bag_with_no_grad_tensors(self):
        dst = self._next_rank()
        remote_embedding = rpc.remote(
            worker_name(dst),
            torch.nn.EmbeddingBag,
            args=(16, 16),
            kwargs={"mode": "sum", "sparse": True},
        )
        local_embedding = torch.nn.EmbeddingBag(16, 16, mode="sum", sparse=True)

        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        # requires_grad = True to record send/recv functions
        per_sample_weights = torch.rand((8), requires_grad=True)
        offsets = torch.LongTensor([0, 4])

        local_res = local_embedding(input, offsets, per_sample_weights)

        # Run backward twice.
        torch.autograd.backward([local_res.sum()], retain_graph=True)
        torch.autograd.backward([local_res.sum()])
        local_grad = local_embedding.weight.grad

        with dist_autograd.context() as context_id:
            res = rpc.rpc_sync(
                worker_name(dst),
                DistAutogradTest._call_remote_embedding,
                args=(remote_embedding, input, offsets, per_sample_weights),
            )

            # Run backward twice to test accumulation of sparse gradients.
            dist_autograd.backward(context_id, [res.sum()], retain_graph=True)
            dist_autograd.backward(context_id, [res.sum()])

            remote_grad = rpc.rpc_sync(
                worker_name(dst),
                DistAutogradTest._get_grad,
                args=(remote_embedding, context_id),
            )

            self.assertEqual(local_grad, remote_grad)


class DistAutogradTest(CommonDistAutogradTest):
    @dist_init
    def test_autograd_context(self):
        # Verify max possible id.
        max_auto_increment = 281474976710655  # 2**48 - 1; top 16 bits hold the worker id.
        self.assertEqual(
            max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id()
        )

        context_ids = []
        for i in range(200):
            with dist_autograd.context() as context_id:
                self.assertEqual(
                    context_id,
                    dist_autograd._retrieve_context(context_id)._context_id(),
                )
                # First 16 bits should be worker_id.
                self.assertEqual(self.worker_id, context_id >> 48)
                context_ids.append(context_id)

        for context_id in context_ids:
            with self.assertRaisesRegex(
                RuntimeError,
                f"Could not find autograd context with id: {context_id}",
            ):
                dist_autograd._retrieve_context(context_id)

    @dist_init
    def test_nested_context(self):
        with dist_autograd.context() as context_id:
            # Nested contexts not supported.
            with self.assertRaisesRegex(
                RuntimeError, "Already have an autograd context id for this thread"
            ):
                with dist_autograd.context() as context_id:
                    pass

    @dist_init
    def test_graph_for_builtin_call(self):
        self._test_graph(torch.add, ExecMode.RPC_SYNC, False)

    @dist_init
    def test_graph_for_python_call(self):
        self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)

    @dist_init
    def test_graph_for_builtin_remote_call(self):
        self._test_graph(torch.add, ExecMode.REMOTE, False)

    @dist_init
    def test_graph_for_python_remote_call(self):
        self._test_graph(my_py_add, ExecMode.REMOTE, False)

    @dist_init
    def test_graph_for_py_nested_call(self):
        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False)

    @dist_init
    def test_graph_for_py_nested_remote_call(self):
        self._test_graph_for_py_nested_call(ExecMode.REMOTE, False)

    @dist_init
    def test_graph_for_py_nested_call_itself(self):
        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False)

    @dist_init
    def test_graph_for_py_nested_remote_call_itself(self):
        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False)

    @dist_init
    def test_no_graph_with_tensors_not_require_grad(self):
        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False)

    @dist_init
    def test_no_graph_with_tensors_not_require_grad_remote(self):
        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False)

    def _test_grad_only_on_return_value(self, exec_mode):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad)
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    worker_name(dst_rank), ret_requires_grad
                ).to_here()
            else:
                raise ValueError(f"Unrecognized ExecMode {exec_mode}")

            dist_autograd.backward(context_id, [ret.sum()])

            rpc.rpc_sync(
                worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
            )

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            grads = dist_autograd.get_gradients(ctx_ids[1])
            self.assertEqual(1, len(grads))
            self.assertIn(requires_grad_tensor, grads)
            self.assertEqual(torch.ones_like(ret), grads[requires_grad_tensor])
            # due to the above get_gradients call, ensure that dist autograd
            # contexts aren't cleaned up until all workers exit context managers
            dist.barrier()

    @dist_init
    def test_grad_only_on_return_value(self):
        self._test_grad_only_on_return_value(ExecMode.RPC_SYNC)

    @dist_init
    def test_grad_only_on_return_value_remote(self):
        self._test_grad_only_on_return_value(ExecMode.REMOTE)

    @dist_init
    def test_rpc_complex_args(self):
        self._test_rpc_complex_args(ExecMode.RPC_SYNC, False)

    @dist_init
    def test_remote_complex_args(self):
        self._test_rpc_complex_args(ExecMode.REMOTE, False)

    @dist_init
    def test_context_cleanup_tensor_with_grad(self):
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)

    @dist_init
    def test_context_cleanup_tensor_no_grad(self):
        t1 = torch.ones(3, 3, requires_grad=False)
        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)

    @dist_init
    def test_context_cleanup_no_tensors(self):
        self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add)

    @dist_init
    def test_context_cleanup_nested_rpc(self):
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        dst_rank = (self.rank + 1) % self.world_size
        args = (t1, t2, dst_rank, self.world_size, 0)
        self.context_cleanup_test_helper(
            rpc_args=args, func=my_py_nested_call, nested=True
        )

    @dist_init
    def test_worker_ids_recorded(self):
        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
        with dist_autograd.context() as context_id:
            # if no tensors require grad, we should still record worker_ids, as
            # the autograd context ID is still passed to other workers.
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            for dst_rank in dst_ranks:
                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync(
                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # all worker_ids in dst_ranks should be recorded.
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(worker_ids, dst_ranks)

            # worker_ids should be recorded when tensors do require grad
            t1.requires_grad = True
            t2.requires_grad = True
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync(
                    worker_name(dst_rank), torch.add, args=(t1, t2)
                )
                rpc.rpc_sync(
                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # all worker_ids in dst_ranks should be recorded.
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(worker_ids, dst_ranks)

    @dist_init
    def test_dist_autograd_profiling(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(3, 3, requires_grad=True)
            loss = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2)).sum()
            with torch.autograd.profiler.profile() as p:
                dist_autograd.backward(context_id, [loss])

        function_events = p.function_events

        def get_event(partial_key):
            return next(event for event in function_events if partial_key in event.name)

        send_event = get_event("SendRpcBackward")
        recv_event = get_event("RecvRpcBackward")
        backward_event = get_event("torch::distributed::autograd::backward")
        # There should be at least one send and one recv event, corresponding to the
        # send/recv functions executed during the backward pass.
        self.assertEqual(send_event.count, 1)
        self.assertEqual(recv_event.count, 1)
        # The CPU total for the backward event should be greater than for send and
        # recv, since applying those functions in the backward pass is a subset of
        # the entire backward pass.
        self.assertGreater(backward_event.cpu_time_total, send_event.cpu_time_total)
        self.assertGreater(backward_event.cpu_time_total, recv_event.cpu_time_total)

    @dist_init
    def test_error_in_context(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(6, 6, requires_grad=True)

            with self.assertRaises(RuntimeError):
                # This should throw an error since matrix sizes don't match.
                rpc.rpc_sync(
                    worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
                )

    @dist_init
    def test_backward_no_grad_on_tensor(self):
        self._backward_no_grad_on_tensor(
            torch.rand((3, 3), requires_grad=True),
            torch.rand((3, 3), requires_grad=True),
            False
        )

    @dist_init
    def test_backward_simple(self):
        self._backward_simple(
            self._next_rank(),
            torch.rand((3, 3), requires_grad=True),
            torch.rand((3, 3), requires_grad=True),
            None,
            False
        )

    @dist_init
    def test_backward_simple_self(self):
        self._backward_simple(
            self.rank,
            torch.rand((3, 3), requires_grad=True),
            torch.rand((3, 3), requires_grad=True),
            None,
            False
        )

    @dist_init
    def test_backward_rref(self):
        callee = worker_name(self._next_rank())
        rref_owner = callee
        self._backward_rref(
            callee,
            rref_owner,
            torch.rand((3, 3), requires_grad=True),
            torch.rand((3, 3), requires_grad=True),
            None,
            False
        )

    @dist_init
    def test_backward_rref_multi(self):
        if self.rank > 0:
            callee = "worker0"
            rref_owner = callee
            self._backward_rref(
                callee,
                rref_owner,
                torch.rand((3, 3), requires_grad=True),
                torch.rand((3, 3), requires_grad=True),
                None,
                False
            )

    @dist_init
    def test_backward_rref_nested(self):
        callee = worker_name((self.rank + 1) % self.world_size)
        rref_owner = worker_name((self.rank + 2) % self.world_size)
        self._backward_rref(
            callee,
            rref_owner,
            torch.rand((3, 3), requires_grad=True),
            torch.rand((3, 3), requires_grad=True),
            None,
            False
        )

    @dist_init
    def test_trainer_ps(self):
        self._test_trainer_ps(
            create_tensor,
            _run_trainer,
            False
        )

    @dist_init
    def test_trainer_ps_torchscript_functions(self):
        # TODO: needs more investigation.
        # There is an rref leak when shutting down; the suspicion is that the
        # rref passed as an arg crosses the pybind boundary and is not garbage
        # collected by python when calling shutdown().
        import torch.distributed.rpc.api as api
        api._ignore_rref_leak = True

        self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False)

    @dist_init
    def test_backward_multiple_round_trips(self):
        self._backward_multiple_round_trips(
            torch.rand((3, 3), requires_grad=True),
            torch.rand((3, 3)),
            torch.rand((3, 3), requires_grad=True),
            torch.rand((3, 3)),
            torch.rand((3, 3), requires_grad=True),
            None,
            False
        )

    @dist_init
    def test_backward_different_tensor_dims(self):
        local_grads = None
        t1 = torch.rand((4, 6), requires_grad=True)
        t2 = torch.rand((6, 5))
        t3 = torch.rand((5, 7), requires_grad=True)
        t4 = torch.rand((7, 9))

        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                val = self._exec_func(exec_mode, torch.matmul, t1, t2)
                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (val, t3, t4))
                loss = val.sum()

                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4
                )
                local_grads = ret if ret else local_grads

    @dist_init
    def test_backward_unused_tensors(self):
        local_grads = None
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        t3 = torch.rand((3, 3), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
                val = self._exec_func(
                    exec_mode,
                    torch.matmul,
                    torch.narrow(s, 0, 0, 1),
                    torch.narrow(s, 0, 2, 1),
                )

                loss = val.sum()
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2, t3
                )
                local_grads = ret if ret else local_grads

    @dist_init
    def test_backward_multiple_output_tensors(self):
        local_grads = None
        t = torch.rand((10, 2), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                tensor_list = self._exec_func(exec_mode, torch.split, t, 2)
                t1 = tensor_list[0]
                t2 = tensor_list[2]
                t3 = tensor_list[4]

                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (t1, t2, t3))

                loss = val.sum()
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t
                )
                local_grads = ret if ret else local_grads

    def _run_test_backward_unused_send_function_in_thread(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)

            # We don't use the result of an RPC function, as a result the
            # backward pass would hang in the "FAST" mode.
            res = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.add, args=(t1, t2)
            )

            val = torch.mul(t1, t2)

            # Run backward, this would hang forever.
            dist_autograd.backward(context_id, [val.sum()])

    @dist_init
    def test_backward_unused_send_function(self):
        # Run the test in a thread which would never finish.
        t = threading.Thread(
            target=self._run_test_backward_unused_send_function_in_thread
        )
        t.daemon = True
        t.start()
        t.join(10)  # Wait for 10s.

        # Verify thread is still alive (indicating backward hasn't completed yet).
        self.assertTrue(t.is_alive())

    @dist_init
    def test_backward_autograd_engine_error(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            # Perform some ops before error simulation.
            tmp = (t1 + t2) * (t1 + t2)
            t3 = SimulateBackwardError.apply(tmp)

            # Run multiple round trips across different nodes and verify the
            # original node receives an error thrown on a node deep in the chain.
            val = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.add, args=(t2, t3)
            )
            val = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.mul, args=(val, t2)
            )
            val = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.matmul, args=(val, t2)
            )
            val = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.div, args=(val, t2)
            )

            with self.assertRaisesRegex(
                RuntimeError, "Error on Node [0-9]+: Simulate error on backward pass"
            ):
                # Run backwards, and validate we receive an error.
                dist_autograd.backward(context_id, [val.sum()])

    @dist_init(clean_shutdown=False)
    @skip_but_pass_in_sandcastle_if(
        IS_MACOS,
        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
    )
    def test_backward_node_failure(self):
        rpc._set_rpc_timeout(5)  # 5 seconds
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            res = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.add, args=(t1, t2)
            )

            # Wait for all RPCs to be done.
            dist.barrier()

            # Kill all odd-rank nodes.
            if self.rank % 2 == 0:
                shutdown_error_regex = self.get_shutdown_error_regex()
                # Wait for all other nodes to die.
                for rank in range(self.world_size):
                    if rank % 2 != 0:
                        wait_until_node_failure(rank, shutdown_error_regex)

                # The shutdown sequence is not very well defined, so as a result
                # we might see any error given by get_shutdown_error_regex().
                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
                    # Run backward and validate that we receive an error since
                    # all other nodes are dead.
                    dist_autograd.backward(context_id, [res.sum()])
            else:
                # Exit all other nodes.
                pass

    @dist_init
    def test_backward_without_context(self):
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)

        context_id = 100  # dummy context_id
        with self.assertRaisesRegex(
            RuntimeError,
            f"Could not find autograd context with id: {context_id}",
        ):
            res = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.add, args=(t1, t2)
            )
            dist_autograd.backward(context_id, [res.sum()])

    @dist_init
    def test_backward_without_rpc(self):
        dst_rank = self.rank
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = torch.add(t1, t2)

            dist_autograd.backward(context_id, [t3.sum()])
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(torch.ones(3, 3), grads[t1])
            self.assertEqual(torch.ones(3, 3), grads[t2])

    @dist_init
    def test_backward_invalid_args(self):
        with dist_autograd.context() as context_id:

            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
                dist_autograd.backward(context_id, None)

            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
                dist_autograd.backward(None, None)

            with self.assertRaisesRegex(
                RuntimeError, "No tensors provided for gradient computation"
            ):
                dist_autograd.backward(context_id, [])

            with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
                t = torch.rand(3, 3)
                dist_autograd.backward(context_id, [t])

            with self.assertRaisesRegex(
                RuntimeError, "is not a scalar, all roots need to be scalar"
            ):
                t = torch.rand(3, 3, requires_grad=True)
                dist_autograd.backward(context_id, [t])

            with self.assertRaisesRegex(
                RuntimeError, "does not have a valid gradient function"
            ):
                t = torch.rand(1, requires_grad=True)
                dist_autograd.backward(context_id, [t])
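
    # For contrast with the invalid calls above, a minimal valid usage of the
    # API exercised throughout this class looks roughly like the following
    # sketch (not executed here; `dst`, `t1` and `t2` are placeholders):
    #
    #     with dist_autograd.context() as context_id:
    #         loss = rpc.rpc_sync(dst, torch.add, args=(t1, t2)).sum()
    #         dist_autograd.backward(context_id, [loss])
    #         grads = dist_autograd.get_gradients(context_id)
    #
    # Unlike torch.autograd.backward(), gradients are accumulated in the
    # autograd context (keyed by context_id) rather than in the tensors'
    # .grad fields.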

    @dist_init
    def test_backward_multiple_roots(self):
        local_grads = None
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
            with dist_autograd.context() as context_id:
                r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum()
                r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum()
                r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
                r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()

                local_grads = self._verify_backwards(
                    exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
                )

    @dist_init
    def test_backward_different_dtypes(self):
        self._backward_different_dtypes(
            torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
            torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
            False
        )

    @dist_init
    def test_backward_simple_python_udf(self):
        self._backward_simple_python_udf(
            torch.rand(3, 3, requires_grad=True),
            torch.rand(3, 3, requires_grad=True),
            False
        )

    @dist_init
    def test_backward_simple_script_call(self):
        self._backward_simple_script_call(
            torch.rand(3, 3, requires_grad=True),
            torch.rand(3, 3, requires_grad=True),
            False
        )

    @staticmethod
    def _complex_python_udf(t1, t2):
        t3 = torch.nn.functional.linear(t1, t2)
        t4 = torch.nn.functional.linear(t2, t3)
        t5 = torch.nn.functional.linear(t3, t4)
        return torch.linalg.multi_dot([t1, t2, t3, t4, t5])

    @dist_init
    def test_backward_complex_python_udf(self):
        # Run the same code locally and with dist autograd and verify the
        # gradients are the same.
        local_grads = None
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(
                    exec_mode, DistAutogradTest._complex_python_udf, t1, t2
                )
                loss = ret.sum()
                local_grads = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )

    @staticmethod
    def _python_udf_with_backward_error(t1, t2):
        t3 = t1 + t2
        t4 = SimulateBackwardError.apply(t3)
        return torch.linalg.multi_dot([t1, t2, t3, t4])

    @staticmethod
    def _nested_rpc_call_backward_error(t1, t2, dst):
        t1 = t1 * t2
        t2 = t1 + t2
        res = rpc.rpc_sync(
            worker_name(dst),
            DistAutogradTest._python_udf_with_backward_error,
            args=(t1, t2),
        )
        return torch.linalg.multi_dot([t1, t2, res])

    @dist_init
    def test_backward_python_udf_error(self):
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                worker_name(self._next_rank()),
                DistAutogradTest._nested_rpc_call_backward_error,
                args=(t1, t2, self._next_rank()),
            )
            with self.assertRaisesRegex(
                RuntimeError, "Simulate error on backward pass"
            ):
                dist_autograd.backward(context_id, [loss.sum()])

    _backward_done = False

    @dist_init(clean_shutdown=False)
    @skip_but_pass_in_sandcastle_if(
        IS_MACOS,
        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
    )
    def test_backward_node_failure_python_udf(self):
        # Set a short timeout to quickly time out failed RPCs.
        rpc._set_rpc_timeout(5)  # 5 seconds
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)

            dst = self._next_rank()
            res = rpc.rpc_sync(
                worker_name(dst),
                my_py_nested_call,
                args=(t1, t2, dst, self.world_size, 1),
            )

            dist.barrier()

            # Kill rank 2 (the last hop of the nested rpc) and verify that
            # rank 0 receives an error.
            if self.rank == 2:
                return

            store = dist.distributed_c10d._get_default_store()
            if self.rank == 0:
                # Wait for rank 2 to die.
                shutdown_error_regex = self.get_shutdown_error_regex()
                wait_until_node_failure(2, shutdown_error_regex)
                # The shutdown sequence is not very well defined, so as a result
                # we might see any error given by get_shutdown_error_regex().
                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
                    # Run backward and validate that we receive an error since
                    # rank 2 is dead.
                    dist_autograd.backward(context_id, [res.sum()])

                # Mark rank 0 as done in the store, since the RPC framework on
                # some nodes might be broken at this point.
                store.set('test_backward_node_failure_python_udf_rank0_done', "True")
            else:
                # Wait for backward to finish on rank 0.
                store.wait(['test_backward_node_failure_python_udf_rank0_done'], timedelta(seconds=10))

    @staticmethod
    def _nested_python_udf(t1, t2, dst):
        t3 = t1 * t2
        t4 = t1 + t2
        res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
        return t1 * t2 * t3 * t4 * res

    @dist_init
    def test_backwards_nested_python_udf(self):
        # Run the equivalent of _nested_python_udf locally.
        self._backwards_nested_python_udf(
            torch.rand(3, 3, requires_grad=True),
            torch.rand(3, 3, requires_grad=True),
            False
        )

    _test_clean_context_backward_context_id = None

    class MyBackwardFunc(Function):
        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        @once_differentiable
        def backward(ctx, input):
            assert DistAutogradTest._test_clean_context_backward_context_id is not None

            # Release the context to simulate an error (use a barrier before
            # releasing the context to ensure all nodes execute the backward
            # function).
            dist.barrier()
            dist_autograd._release_context(
                DistAutogradTest._test_clean_context_backward_context_id
            )

            # Verify all contexts are cleaned up.
            assert _all_contexts_cleaned_up()

            return input

    @dist_init
    def test_clean_context_during_backward(self):
        """
        This test simulates the situation where the 'backward' call might throw
        an exception locally, which would lead to the autograd context being
        cleaned up if we're using the context manager. As a result, the
        autograd context might be cleaned up while some threads are still
        using it.

        It is fine for the 'backward' call to throw an exception in this test,
        but the process should not crash.
        """
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        context = dist_autograd._new_context()
        context_id = context._context_id()
        DistAutogradTest._test_clean_context_backward_context_id = context_id

        # Send the context id to all nodes.
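        # _set_rpc_done records the context id in each peer's known_context_ids
        # set, which is what _all_contexts_cleaned_up() (invoked inside
        # MyBackwardFunc.backward above) later checks to confirm the released
        # context has been cleaned up everywhere.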
        for i in range(0, self.world_size):
            if i != self.rank:
                rank_distance = (i - self.rank + self.world_size) % self.world_size
                rpc.rpc_sync(
                    worker_name(i),
                    _set_rpc_done,
                    args=(context_id, rank_distance),
                )

        dist.barrier()

        # Verify all context ids have been received.
        self.assertEqual(self.world_size - 1, len(known_context_ids))

        t1 = torch.rand((3, 3), requires_grad=True)
        for i in range(0, 100):
            dst = self._next_rank()
            t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))

        # Call MyBackwardFunc as the first op of the backward pass to
        # ensure we release the context early in the backward pass.
        t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
        self.assertEqual(100, len(context._send_functions()))

        context_id = 100  # dummy context_id
        with self.assertRaisesRegex(
            RuntimeError,
            f"Could not find autograd context with id: {context_id}",
        ):
            dist_autograd.backward(context_id, [t1.sum()])

        # HACK: Killing workers since otherwise the autograd engine gets stuck on
        # other nodes. The proper fix would be addressing:
        # https://github.com/pytorch/pytorch/issues/27643, which would inform
        # other nodes about the failure.
        # The autograd engine gets stuck on other nodes since they're waiting to
        # receive gradients from the node that received an error (and as a
        # result it didn't execute the rest of the graph).
        dist.barrier()
        rpc.shutdown(graceful=False)
        sys.exit(0)

    @classmethod
    def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
        embedding = embedding_rref.local_value()
        return embedding(input, offsets, per_sample_weights)

    @classmethod
    def _get_grad(cls, embedding_rref, context_id):
        embedding = embedding_rref.local_value()
        grad_map = dist_autograd.get_gradients(context_id)
        return grad_map[embedding.weight]

    @classmethod
    def _mixed_requires_grad_operaton(cls, t1, t2):
        if t2.requires_grad:
            return t1 - t2
        else:
            return t1 * t2

    @dist_init
    def test_mixed_requires_grad(self):
        self._mixed_requires_grad(
            torch.rand(3, 3, requires_grad=True),
            torch.rand(3, 3, requires_grad=False),
            False
        )

    class TestDebugInfoFunc(Function):
        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        @once_differentiable
        def backward(ctx, input):
            debug_info = dist_autograd._get_debug_info()
            assert debug_info is not None
            backward_passes = int(debug_info["num_current_backward_passes"])

            # Hard to validate exact numbers because of the distributed nature.
            # We can't use a barrier() here since that would block the single
            # CPU thread available for autograd and can cause deadlocks.
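            # The loose bounds below reflect that anywhere from one to several
            # backward passes may legitimately be in flight when this hook runs
            # (the upper bound of 4 presumably matches the fixture's world size).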
            assert backward_passes >= 1 and backward_passes <= 4
            return input

    @dist_init
    def test_debug_info(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        with dist_autograd.context() as context_id:
            i = 0
            res = {}
            res[i] = t1
            for rank in range(self.world_size):
                if rank != self.rank:
                    res[i + 1] = rpc.rpc_sync(
                        worker_name(rank), torch.add, args=(res[i], t2)
                    )
                    i += 1

            # Call the custom function in the middle of the backward pass to
            # ensure all nodes are still waiting on a backward().
            res[i + 1] = DistAutogradTest.TestDebugInfoFunc.apply(res[i])
            i += 1

            for rank in range(self.world_size):
                if rank != self.rank:
                    res[i + 1] = rpc.rpc_sync(
                        worker_name(rank), torch.add, args=(res[i], t2)
                    )
                    i += 1

            dist_autograd.backward(context_id, [res[i].sum()])

            debug_info = dist_autograd._get_debug_info()
            num_autograd_context = int(debug_info["num_autograd_contexts"])
            # Need at least one context and not more than 4.
            self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4)

        for rd in range(self.world_size - 1):
            rpc.rpc_sync(
                worker_name((self.rank + rd + 1) % self.world_size),
                _set_rpc_done,
                args=(context_id, rd + 1),
            )

        dist.barrier()

        # Validate the debug info.
        debug_info = dist_autograd._get_debug_info()
        assert debug_info is not None
        self.assertEqual(0, int(debug_info["num_current_backward_passes"]))
        # Only `num_current_backward_passes` and `num_autograd_contexts` are reported.
        self.assertTrue(len(debug_info) == 2)

        self.assertTrue(_all_contexts_cleaned_up())

        # All contexts should be cleaned up.
        debug_info = dist_autograd._get_debug_info()
        self.assertEqual(0, int(debug_info["num_autograd_contexts"]))

    @staticmethod
    def _workload_thread():
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        with dist_autograd.context() as context_id:
            t3 = rpc.rpc_sync("worker0", torch.add, args=(t1, t2))
            t4 = rpc.rpc_sync("worker0", torch.mul, args=(t2, t3))
            t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4))
            t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5))

            dist_autograd.backward(context_id, [t6.sum()])

    @dist_init
    def test_async_dist_autograd(self):
        """
        This test ensures that async processing for distributed autograd works
        appropriately. This is achieved by spawning multiple threads and
        hammering a single node with a lot of backward() calls.
        """

        initialize_pg(self.file_init_method, self.rank, self.world_size)
        if self.rank != 0:
            # All other ranks schedule work on rank 0.
            threads = []
            for i in range(20):
                t = threading.Thread(target=DistAutogradTest._workload_thread)
                t.start()
                threads.append(t)

            for thread in threads:
                thread.join()

        dist.barrier()

    @dist_init
    def test_backward_accumulate_grads(self):
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        with dist_autograd.context() as context_id:
            t3 = torch.matmul(t1, t2)
            # Run backward twice.
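            # First accumulate gradients locally into t1.grad / t2.grad via the
            # vanilla autograd engine; the distributed backward below should
            # accumulate the same values into the dist autograd context, which
            # the assertions at the end compare against.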
            torch.autograd.backward([t3.sum()], retain_graph=True)
            torch.autograd.backward([t3.sum()])

            t3 = rpc.rpc_sync(
                worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
            )
            # Run backward twice.
            dist_autograd.backward(context_id, [t3.sum()], retain_graph=True)
            dist_autograd.backward(context_id, [t3.sum()])

            # Verify the gradients are the same for local and remote execution.
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(t1.grad, grads[t1])
            self.assertEqual(t2.grad, grads[t2])

    @staticmethod
    def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
        return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))

    @dist_init
    def test_nested_backward_accumulate_grads(self):
        self._nested_backward_accumulate_grads(
            torch.rand(3, 3, requires_grad=True),
            torch.rand(3, 3, requires_grad=True),
            False
        )

    @dist_init
    def test_multiple_backward(self):
        self._multiple_backward(
            torch.rand(3, 3, requires_grad=True),
            torch.rand(3, 3, requires_grad=True),
            False
        )

    @dist_init(clean_shutdown=False)
    def test_multiple_backward_with_errors(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        with dist_autograd.context() as context_id:
            loss = rpc.rpc_sync(
                f'worker{self._next_rank()}',
                DistAutogradTest._python_udf_with_backward_error,
                args=(t1, t2)).sum()

            try:
                # Run backward in a loop multiple times.
                for i in range(100):
                    if i < 50:
                        with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"):
                            dist_autograd.backward(context_id, [loss], retain_graph=True)
                    elif i > 50:
                        # Recovered from the error.
                        dist_autograd.backward(context_id, [loss], retain_graph=True)
                    else:
                        dist.barrier()
                        SimulateBackwardError._simulate_error = False
                        dist.barrier()
            finally:
                # Sync before resetting the flag.
                dist.barrier()

                # Reset the flag.
                SimulateBackwardError._simulate_error = True

    @dist_init
    def test_backward_verify_hooks(self):
        t1 = torch.ones((3, 3), requires_grad=True)
        # Double the gradient.
        t1.register_hook(lambda grad: grad * 2)
        t2 = torch.ones((3, 3), requires_grad=True)
        local_grads = None
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(exec_mode, torch.matmul, t1, t2)
                loss = ret.sum()
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2
                )
                local_grads = ret if ret else local_grads

    @dist_init
    def test_no_grad_copy(self):
        '''
        Similar to the test in test_autograd.py.
        '''
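        # What "no grad copy" means here: when a gradient produced in backward
        # can be accumulated directly (a single, contiguous grad), dist
        # autograd may reuse ("steal") that buffer, so its data_ptr() matches
        # the one recorded by the Function below; otherwise it must clone.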
        # Create an autograd function that saves the grad pointer as a class static.
        class MyFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                MyFunc.static_grad_ptr = grad.data_ptr()
                return grad, grad

        class MyFuncSingleGrad(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp):
                return inp

            @staticmethod
            def backward(ctx, grad):
                MyFuncSingleGrad.static_grad_ptr = grad.data_ptr()
                return grad

        class NonContGradFunc(Function):
            @staticmethod
            def forward(ctx, inp1):
                ctx.size = inp1.size()
                return torch.tensor([1.])

            @staticmethod
            def backward(ctx, grad):
                return torch.ones(1).expand(ctx.size)

        a = torch.randn(5, 6, requires_grad=True)
        b = torch.randn(5, 6, requires_grad=True)
        # A non-contiguous grad should be copied.
        with dist_autograd.context() as context_id:
            dist_autograd.backward(context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))])
            grads = dist_autograd.get_gradients(context_id)
            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)

        # Test case that should trigger no copy for a.
        with dist_autograd.context() as context_id:
            dist_autograd.backward(context_id, [MyFuncSingleGrad.apply(a)[1][0]])
            grads = dist_autograd.get_gradients(context_id)
            p_g = MyFuncSingleGrad.static_grad_ptr
            p_a = grads[a].data_ptr()
            # Verify there was no clone.
            self.assertTrue(p_a == p_g)

        # Test case that should trigger a copy for both a and b. This is
        # different in the distributed autograd case since we hold
        # a reference to all grads in a vector until all accumulation is done.
        with dist_autograd.context() as context_id:
            dist_autograd.backward(context_id, [MyFunc.apply(a, b)[1][0]])
            grads = dist_autograd.get_gradients(context_id)
            p_g = MyFunc.static_grad_ptr
            p_a = grads[a].data_ptr()
            p_b = grads[b].data_ptr()
            # Check that a and b use different grad buffers.
            self.assertFalse(p_a == p_b)
            # Both should be copied.
            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)

    @dist_init
    def test_no_grad_copy_sparse(self):
        # Create an autograd function that saves the grad pointer as a class static.
        class MyFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp):
                return inp

            @staticmethod
            def backward(ctx, grad):
                MyFunc.static_grad_ptr = grad._values().data_ptr()
                return grad

        class NonContGradFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                # Create a sparse tensor with non-contiguous indices and values
                # and return it as the grad.
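                # expand() below produces stride-0 (i.e. non-contiguous) views
                # for the indices and values, which should force dist autograd
                # to copy when accumulating this grad.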
                v = torch.rand(1, 3)
                i = torch.ones(1, 1, dtype=torch.long)
                nv = v.expand(8, 3)
                ni = i.expand(1, 8)
                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
                return ngrad, ngrad

        a = torch.randn(10, 3, requires_grad=True)
        b = torch.randn(10, 3, requires_grad=True)
        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.tensor([0, 4])
        import torch.nn.functional as F

        # Test case that should trigger no copy for a.
        with dist_autograd.context() as context_id:
            emb_matrix = MyFunc.apply(a)
            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
            dist_autograd.backward(context_id, [loss], retain_graph=True)
            grads = dist_autograd.get_gradients(context_id)
            p_g = MyFunc.static_grad_ptr
            p_a = grads[a]._values().data_ptr()
            # Check that a uses the same buffer.
            self.assertTrue(p_a == p_g)

            # Run backward multiple times.
            for i in range(10):
                dist_autograd.backward(context_id, [loss], retain_graph=True)

        # With non-contiguous indices and values, we should trigger a copy.
        with dist_autograd.context() as context_id:
            emb_matrix = NonContGradFunc.apply(a, b)
            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
            dist_autograd.backward(context_id, [loss], retain_graph=True)
            grads = dist_autograd.get_gradients(context_id)
            p_g = NonContGradFunc.static_grad_ptr
            p_a = grads[a]._values().data_ptr()
            p_b = grads[b]._values().data_ptr()
            # Check that a and b use different grad buffers.
            self.assertFalse(p_a == p_b)
            # Verify we cloned both grads.
            self.assertFalse(p_a == p_g)
            self.assertFalse(p_b == p_g)

            # Run backward multiple times to verify accumulation.
            for i in range(10):
                dist_autograd.backward(context_id, [loss], retain_graph=True)

    @dist_init
    def test_grad_copy_sparse_indices_extra_ref(self):
        # Create an autograd function that saves the grad pointer as a class static.
        class MyFunc(Function):
            static_grad_ptr = None
            static_grad_indices_ref = None
            static_grad_values_ref = None

            @staticmethod
            def forward(ctx, inp):
                return inp

            @staticmethod
            def backward(ctx, grad):
                MyFunc.static_grad_ptr = grad._values().data_ptr()
                # indices() and values() return views, so holding onto
                # references to them does not increment the refcount of the
                # indices and values inside the sparse tensor.
                MyFunc.static_grad_indices_ref = grad._indices()
                MyFunc.static_grad_values_ref = grad._values()
                return grad

        a = torch.randn(10, 3, requires_grad=True)
        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.tensor([0, 4])
        import torch.nn.functional as F

        with dist_autograd.context() as context_id:
            emb_matrix = MyFunc.apply(a)
            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
            dist_autograd.backward(context_id, [loss], retain_graph=True)
            grads = dist_autograd.get_gradients(context_id)
            p_g = MyFunc.static_grad_ptr
            p_a = grads[a]._values().data_ptr()
            self.assertIsNotNone(MyFunc.static_grad_indices_ref)
            self.assertIsNotNone(MyFunc.static_grad_values_ref)
            # The grad is stolen, since static_grad_indices_ref and
            # static_grad_values_ref are holding onto views and don't bump the
            # refcount.
            self.assertTrue(p_g == p_a)

    @dist_init
    def test_post_hooks(self):
        self.hook_called_times = 0

        def post_hook_add_one(output_grads, input_grads):
            self.hook_called_times += 1
            return output_grads

        def post_hook_add_two(output_grads, input_grads):
            self.hook_called_times += 2
            return output_grads

        t = torch.rand(10, 10, requires_grad=True)
        a = t + t

        # Register post hooks.
        accumulate_grad_0 = a.grad_fn.next_functions[0][0]
        accumulate_grad_0.register_hook(post_hook_add_one)
        accumulate_grad_0.register_hook(post_hook_add_two)

        accumulate_grad_1 = a.grad_fn.next_functions[1][0]
        accumulate_grad_1.register_hook(post_hook_add_two)

        with dist_autograd.context() as context_id:
            loss = a.sum()
            dist_autograd.backward(context_id, [loss])
            self.assertEqual(5, self.hook_called_times)
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(1, len(grads))
            self.assertTrue(t in grads)

    @staticmethod
    def _slow_add(t1, t2):
        time.sleep(1)
        t3 = t1 + t2
        t3.requires_grad = True
        return t3

    @dist_init
    def test_thread_local_context_id(self):
        t1 = torch.rand((3, 3))
        t2 = torch.rand((3, 3))

        t3 = t1 + t2
        t3.requires_grad = True
        t3.sum().backward()

        dst = worker_name((self.rank + 1) % self.world_size)
        rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2))

        with dist_autograd.context() as context_id:
            loss = rref.to_here().sum()
            # Due to the slow add, the continuation of this backward pass will
            # be invoked by the previous rpc.remote thread, which does not have
            # a valid context_id. So this tests whether we propagate
            # thread_local state properly when jumping across threads on the
            # server side.
            dist_autograd.backward(context_id, [loss])
            self.assertTrue(
                rpc.rpc_sync(
                    dst,
                    _compare_owner_value,
                    args=(context_id, rref, t3.grad)
                )
            )


class CudaDistAutogradTest(CommonDistAutogradTest):
    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_gpu_simple(self):
        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
        t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
        (t1 + t2).sum().backward()
        with dist_autograd.context() as context_id:
            t3 = t1 + t2
            dist_autograd.backward(context_id, [t3.sum()])
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertEqual(t1.grad, grads[t1])
            self.assertEqual(t2.grad, grads[t2])

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_gpu_to_cpu_continuation(self):
        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
        t2 = torch.rand(3, 3, requires_grad=True)
        # Run a few iterations.
        for i in range(3):
            t1.grad = None
            t2.grad = None
            # Root is on CPU.
            local_grads = None
            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
                with dist_autograd.context() as context_id:
                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
                    t4 = t3.cuda(0) + t1
                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
                    t6 = t5.cuda(0) + t4
                    t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5)
                    # The autograd graph consists of CPU -> GPU -> CPU execution.
                    ret = self._verify_backwards(
                        exec_mode, [t7.sum()], context_id, local_grads, t1, t2
                    )
                    local_grads = ret if ret else local_grads

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_gpu_to_cpu_continuation_gpu_root(self):
        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
        t2 = torch.rand(3, 3, requires_grad=True)
        # Run a few iterations.
        for i in range(3):
            t1.grad = None
            t2.grad = None
            # Root is on GPU.
            local_grads = None
            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
                with dist_autograd.context() as context_id:
                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
                    t4 = t3.cuda(0) + t1
                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
                    t6 = t5.cuda(0) + t4
                    # The autograd graph consists of CPU -> GPU -> CPU execution.
                    ret = self._verify_backwards(
                        exec_mode, [t6.sum()], context_id, local_grads, t1, t2
                    )
                    local_grads = ret if ret else local_grads


class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
    # Reuses a simplified helper function from DistAutogradTest to ensure the
    # autograd context is successfully cleaned up even when RPCs are failing.
    def context_cleanup_test_helper(self, rpc_args, func):
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        # Test that in dist autograd, when tensors communicated over RPC do
        # NOT require grad, we still clean up the dist autograd contexts created
        # on other nodes. This is because the autograd context is still
        # communicated over RPC even if the tensor arguments do not require
        # grad, as it is possible that the response could.
        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}

        with dist_autograd.context() as context_id:
            for dst_rank in dst_ranks:
                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
                rpc.rpc_sync(
                    worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
        # The thread's context id should be cleaned up.
        with self.assertRaises(RuntimeError):
            dist_autograd._retrieve_context(context_id)
        # Ensure all peers have finished mutating the
        # `known_context_ids` set.
        dist.barrier()
        # Check that all contexts have been cleaned up.
        success = _all_contexts_cleaned_up()
        self.assertTrue(success)

    # No faulty_messages defined, so this fails all retryable messages; see
    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
    @dist_init
    def test_context_cleanup_tensor_with_grad(self):
        t1 = torch.ones(3, 3, requires_grad=True)
        t2 = torch.zeros(3, 3, requires_grad=True)
        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)

    @dist_init
    def test_verify_backend_options(self):
        self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE)
        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)


class WrapperModule(nn.Module):
    def __init__(self, model, device):
        super().__init__()
        self.model = model.to(device)

    def forward(self, *args):
        return self.model(*args)

    def gradients(self, ctx_id):
        grads = dist_autograd.get_gradients(ctx_id)
        return [grads[p] for p in self.model.parameters()]


class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):

    @skip_if_lt_x_gpu(4)
    def test_device_maps_backward_pass(self):
        options = self.rpc_backend_options
        dst = worker_name((self.rank + 1) % self.world_size)

        # The reverse of this device mapping should be used for the backward pass.
        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        t1 = torch.rand(10, device=self.rank, requires_grad=True)
        t2 = torch.rand(10, device=self.rank, requires_grad=True)
        with dist_autograd.context() as context_id:
            res = rpc.rpc_sync(dst, torch.add, args=(t1, t2))
            dist_autograd.backward(context_id, [res.sum()])
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(torch.ones(10), grads[t1])
            self.assertEqual(torch.ones(10), grads[t2])
            self.assertEqual(t1.device, grads[t1].device)
            self.assertEqual(t2.device, grads[t2].device)

        rpc.shutdown()

    class MyRemoteCompute(torch.nn.Module):
        def forward(self, input):
            input = input * 2.0
            return input

    class MyLocalCompute(torch.nn.Module):
        def __init__(self, next_stage):
            super().__init__()
            self.next_stage = next_stage

        def forward(self, input):
            return self.next_stage.rpc_sync().forward(input)

    @skip_if_lt_x_gpu(4)
    def test_dist_autograd_sync_streams(self):

        options = self.rpc_backend_options
        dst = worker_name((self.rank + 1) % self.world_size)

        # The reverse of this device mapping should be used for the backward pass.
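        # For example, with this map a tensor sent from this worker's device
        # `self.rank` lands on the peer's device `self.rank + 1`; gradients
        # flowing back in the backward pass take the inverse mapping so they
        # arrive on the original device (as test_device_maps_backward_pass
        # above asserts via grads[t1].device).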
        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute)
        local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute)
        for _ in range(10):
            input = torch.rand([1000, 10000], device=self.rank, requires_grad=True)
            # Run local autograd.
            result = input * 2.0
            r = random.random()
            loss = result.sum() * r
            loss.backward()

            # Run distributed autograd.
            with dist_autograd.context() as context_id:
                result = local_compute(input)
                loss = result.sum() * r
                dist_autograd.backward(context_id, [loss])

                # Compare grads.
                grads = dist_autograd.get_gradients(context_id)
                self.assertEqual(input.grad, grads[input])

        rpc.shutdown()

    @skip_if_lt_x_gpu(4)
    def test_gradients_synchronizations(self):
        options = self.rpc_backend_options
        for peer_rank in range(self.world_size):
            options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        if self.rank == 0:
            # This rank acts as the master.
            layers = [nn.Linear(2000, 2000) for _ in range(self.world_size - 1)]
            local_layers = [l.to(0) for l in layers]
            remote_layers = []
            for rank in range(1, self.world_size):
                remote_layers.append(rpc.remote(
                    worker_name(rank),
                    WrapperModule,
                    args=(layers[rank - 1], rank)
                ))

            x = torch.randn(5000, 2000).to(0)
            # Local iteration.
            local_model = nn.Sequential(*local_layers)
            local_model(x).sum().backward()

            # Remote iteration.
            with dist_autograd.context() as context_id:
                for remote_layer in remote_layers:
                    x = remote_layer.rpc_sync().forward(x)

                dist_autograd.backward(context_id, [x.sum()])

                futs = []
                for remote_layer in remote_layers:
                    futs.append(remote_layer.rpc_async().gradients(context_id))

                for i in range(len(futs)):
                    local_gradients = [p.grad for p in local_layers[i].parameters()]
                    for g1, g2 in zip(futs[i].wait(), local_gradients):
                        self.assertEqual(g1, g2)

        rpc.shutdown()