# mypy: allow-untyped-defs

import concurrent.futures
import contextlib
import json
import os
import sys
import threading
import time

from collections import namedtuple
from functools import partial
from threading import Event
from threading import Lock
from unittest import mock

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd
from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info, WorkerInfo
from torch.distributed.rpc.api import _use_rpc_pickler, _thread_local_var, _wait_all
from torch.distributed.rpc.internal import (
    PythonUDF,
    RPCExecMode,
    _internal_rpc_pickler,
    _build_rpc_profiling_key,
)
from torch.futures import Future
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
    captured_output,
    tp_transports,
)
from torch.testing._internal.common_utils import (
    IS_MACOS,
    load_tests,
    skip_but_pass_in_sandcastle_if,
    get_cycles_per_ms,
)

from torch.testing._internal.dist_utils import (
    dist_init,
    get_function_event,
    initialize_pg,
    wait_until_node_failure,
    wait_until_pending_futures_and_users_flushed,
    wait_until_owners_and_forks_on_rank,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.common_utils import TemporaryFileName

from torch.autograd.profiler_legacy import profile as _profile
import operator


def foo_add():
    return torch.add(torch.ones(1), torch.ones(1))

def udf_with_torch_ops(device=-1, use_record_function=False):
    device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device)
    record_function_ctx = (
        torch.autograd.profiler.record_function("##forward##")
        if use_record_function
        else contextlib.nullcontext()
    )
    with device_ctx, record_function_ctx:
        t1, t2 = torch.ones(1), torch.ones(1)
        t = torch.add(t1, t2)
        t = torch.mul(t, t)
        t = t.relu()
        t = t.sigmoid()

# Events (operator invocations) that are expected to run as part of the above
# function.
EXPECTED_REMOTE_EVENTS = [
    "aten::ones",
    "aten::ones",
    "aten::add",
    "aten::mul",
    "aten::relu",
    "aten::clamp_min",
    "aten::sigmoid",
]

# Remote operations are prefixed with the following string for RPC profiling.
REMOTE_OP_STR = "#remote_op: "


VALUE_FUTURE = concurrent.futures.Future()
DONE_FUTURE = concurrent.futures.Future()

FIFTY_MIL_CYCLES = 50000000

_rpc_barrier_count = 0

def _increment_count():
    global _rpc_barrier_count
    _rpc_barrier_count += 1

def _reset_count():
    global _rpc_barrier_count
    _rpc_barrier_count = 0

class StubRpcAgent:
    def __init__(self, world_size):
        self.world_size = world_size

    def get_worker_infos(self):
        return {
            WorkerInfo(name=worker_name(rank), id=rank)
            for rank in range(self.world_size)
        }


def _stub_construct_rpc_backend_options_handler(**kwargs):
    return mock.Mock()  # RpcBackendOptions.
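# Illustrative helper (a minimal sketch, not used by the tests below): it shows how
# EXPECTED_REMOTE_EVENTS can be cross-checked by running udf_with_torch_ops under the
# local legacy profiler. Exact operator names can vary across PyTorch builds, so this
# is a sanity check rather than a strict contract.
def _sanity_check_expected_remote_events():
    with _profile() as prof:
        udf_with_torch_ops()
    local_event_names = {event.name for event in prof.function_events}
    # Any expected remote event that never shows up in a local run of the same UDF
    # would indicate that EXPECTED_REMOTE_EVENTS has gone stale.
    return [name for name in EXPECTED_REMOTE_EVENTS if name not in local_event_names]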
122 123 124def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options): 125 return StubRpcAgent(world_size=world_size) 126 127 128def set_value(value): 129 VALUE_FUTURE.set_result(value) 130 131 132def wait_for_value_future(): 133 return VALUE_FUTURE.result() 134 135 136def set_and_check_done(value): 137 VALUE_FUTURE.set_result(value) 138 return DONE_FUTURE.result() 139 140 141# it is used to test python user defined function over rpc 142# classes and functions are used to test python user defined class and 143# methods over rpc 144TensorClass = namedtuple("TensorClass", ["tensors"]) 145 146class MyPickleClass: 147 def __init__(self) -> None: 148 self.t = None 149 150 def __getstate__(self): 151 (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize( 152 PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None) 153 ) 154 return (pickled_python_udf, tensors) 155 156 def __setstate__(self, obj): 157 python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1]) 158 result = python_udf.func(python_udf.args[0], python_udf.args[1]) 159 self.t = result 160 161 def set(self, val): 162 self.t = val 163 164 165class SlowPickleClass: 166 def __init__(self, t): 167 self.t = t 168 169 def __getstate__(self): 170 time.sleep(self.t) 171 return (self.t, ) 172 173 def __setstate__(self, obj): 174 self.t = obj[0] 175 time.sleep(self.t) 176 177 178class MyClass: 179 def __init__(self, a, delay=False): 180 self.a = a 181 # delay initialization to simulate errors if specified 182 if delay: 183 time.sleep(2) 184 185 def my_instance_method(self, b): 186 return self.a + b 187 188 @classmethod 189 def my_class_method(cls, d, e): 190 return d + e 191 192 @staticmethod 193 def my_static_method(f): 194 return f > 10 195 196 def increment_value(self, increment): 197 self.a += increment 198 199 def get_value(self): 200 return self.a 201 202 def my_slow_method(self, my_tensor_arg): 203 time.sleep(5) 204 return torch.add(self.a, my_tensor_arg) 205 206 207def _call_method_on_rref(method, rref, *args, **kwargs): 208 return method(rref.local_value(), *args, **kwargs) 209 210 211def get_rref_list(values): 212 return [RRef(MyClass(a)) for a in values] 213 214 215def add_rref_to_value(rref, value): 216 return rref.to_here() + value 217 218 219def run_nested_pickle(pickle_cls_instance, tensor): 220 return pickle_cls_instance.t + tensor 221 222def build_sparse_tensor(coalesce=False): 223 i = [[0, 1, 1], [2, 0, 2]] 224 v = [3, 4, 5] 225 tensor = torch.sparse_coo_tensor(i, v, (2, 3)) 226 if coalesce: 227 tensor = tensor.coalesce() 228 return tensor 229 230def build_complex_tensors(): 231 a = torch.ones(3, 3) 232 b = [a, a] 233 c = [b, b] 234 d = [a, b] 235 e = {a: d} 236 return [a, b, c, d, e] 237 238def non_cont_test(t_view, t_cont): 239 if t_view.is_contiguous(): 240 raise Exception('t_view is contiguous!') # noqa: TRY002 241 if not t_cont.is_contiguous(): 242 raise Exception('t_cont is not contiguous!') # noqa: TRY002 243 if not torch.equal(t_view, t_cont): 244 raise Exception('t_view is not equal to t_cont!') # noqa: TRY002 245 return t_view 246 247def my_function(a, b, c): 248 return a + b + c 249 250 251def my_tensor_function(a, b): 252 return a + b 253 254def my_container_sum(a): 255 result = a[0] 256 for tensor in a[1:]: 257 result += tensor 258 return result 259 260 261def my_sleep_func(seconds=1): 262 time.sleep(seconds) 263 return torch.mul(torch.tensor(1), torch.tensor(1)) 264 265 266def my_complex_tensor_function(list_input, tensor_class_input, dict_input): 267 res 
= list_input[0] 268 for t in list_input: 269 res += t 270 for v in dict_input.values(): 271 res += v 272 complex_tensors = tensor_class_input.tensors 273 return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2]) 274 275 276def my_rref_function(rref_a, rref_b): 277 return rref_a.to_here() + rref_b.to_here() 278 279 280def delayed_add(a, b, seconds=0.05): 281 time.sleep(seconds) 282 return a + b 283 284 285def identity(a): 286 return a 287 288def no_result(): 289 print("do nothing") 290 291def raise_or_inc(value): 292 if value.numel() == 2: 293 raise ValueError("Expected error") 294 return value + 1 295 296def nested_rpc(dst): 297 return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) 298 299 300def nested_rpc_sparse(dst): 301 return rpc.rpc_sync( 302 dst, 303 torch.add, 304 args=(build_sparse_tensor(), build_sparse_tensor()) 305 ) 306 307 308def multi_layer_nested_async_rpc(dst, world_size, ttl): 309 # this method returns immediately without blocking the callee, but will 310 # generate additional requests. 311 if ttl > 0: 312 current_dst = worker_name(dst) 313 next_dst = (dst + 1) % world_size 314 rpc.rpc_async( 315 current_dst, 316 multi_layer_nested_async_rpc, 317 args=(next_dst, world_size, ttl - 1), 318 ) 319 return 0 320 321 322def nested_rref(dst): 323 return ( 324 rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)), 325 rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)), 326 ) 327 328 329def nested_rref_sparse(dst): 330 return ( 331 rpc.remote( 332 dst, 333 torch.add, 334 args=(build_sparse_tensor(), build_sparse_tensor()) 335 ), 336 rpc.remote( 337 dst, 338 torch.add, 339 args=(build_sparse_tensor(), build_sparse_tensor()) 340 ), 341 ) 342 343 344def nested_remote(dst): 345 rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3)) 346 return rref.to_here() 347 348def nested_remote_sparse(dst): 349 rref = rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())) 350 return rref.to_here() 351 352 353def rref_forward_chain(dst, world_size, rref, ttl): 354 if ttl > 0: 355 current_dst = worker_name(dst) 356 next_dst = (dst + 1) % world_size 357 ret_rref = rpc.remote( 358 current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1) 359 ) 360 return [ret_rref] 361 else: 362 return rref.to_here() 363 364 365def rpc_return_rref(dst): 366 return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)) 367 368 369def light_rpc(): 370 return 0 371 372 373def heavy_rpc(tensor): 374 for i in range(1, 100): 375 tensor *= i 376 tensor /= i + 1 377 return 0 378 379 380def heavy_rpc_sparse(tensor): 381 for i in range(1, 100): 382 tensor *= i 383 tensor = tensor / (i + 1) 384 return 0 385 386@torch.jit.script 387def heavy_rpc_torchscript(tensor): 388 for i in range(1, 100): 389 tensor *= i 390 tensor /= i + 1 391 return 0 392 393 394@torch.jit.script 395def my_script_func(tensor): 396 return torch.add(tensor, tensor) 397 398 399expected_err = "Expected error" 400 401# Note that it needs to inherit from Exception, not BaseException. 
See comment 402# in rpc/internal.py 403class CustomException(Exception): 404 def __init__(self, bool, msg): 405 self.bool = bool 406 super().__init__(msg) 407 408def raise_func(): 409 raise ValueError(expected_err) 410 411def custom_raise_func(): 412 raise CustomException(True, "foo") 413 414@torch.jit.script 415def raise_func_script(expected_err: str) -> torch.Tensor: 416 raise ValueError(expected_err) 417 418expected_err_escape = "\nFirst line of error \n next line of error \n last line of error" 419def raise_func_escape(): 420 raise ValueError(expected_err_escape) 421 422 423global_rref = None 424 425 426def set_global_rref(rref): 427 global global_rref 428 global_rref = rref 429 430 431def clear_global_rref(): 432 global global_rref 433 global_rref = None 434 435 436def check_rref_confirmed(rref): 437 return rref.confirmed_by_owner() 438 439 440def get_rref_debug_info(): 441 return _rref_context_get_debug_info() 442 443 444def add_use_future_cb(to, x, y, z): 445 out = concurrent.futures.Future() 446 447 def callback(fut): 448 out.set_result(fut.wait() + z) 449 450 fut = rpc.rpc_async(to, torch.add, args=(x, y)) 451 fut.then(callback) 452 return out.result() 453 454 455def get_events_from_profile(profile_rref): 456 return profile_rref.local_value().process_global_function_events 457 458 459def add_use_future_set_result(to, x, y, z): 460 out = torch.futures.Future() 461 fut = rpc.rpc_async(to, torch.add, args=(x, y)) 462 fut.then(lambda fut : out.set_result(fut.wait() + z)) 463 return out.wait() 464 465 466def add_use_future_nested_cb(to, x, y, z): 467 out = torch.futures.Future() 468 469 def callback(fut1): 470 fut2 = rpc.rpc_async(to, torch.add, args=(fut1.wait(), z)) 471 fut2.then(lambda fut2 : out.set_result(fut2.wait())) 472 473 fut1 = rpc.rpc_async(to, torch.add, args=(x, y)) 474 fut1.then(callback) 475 return out.wait() 476 477 478def fail_on_fut(fut): 479 pass 480 481 482@rpc.functions.async_execution 483def async_raise_func(): 484 raise RuntimeError("Expected error") 485 486 487@rpc.functions.async_execution 488def async_wrong_type(): 489 return torch.zeros(2, 2) 490 491 492@rpc.functions.async_execution 493def async_add(to, x, y): 494 return rpc.rpc_async(to, torch.add, args=(x, y)) 495 496 497def slow_add(x, y, device="cpu"): 498 time.sleep(1) 499 x = x.to(device) 500 y = y.to(device) 501 return torch.add(x, y).cpu() 502 503 504@rpc.functions.async_execution 505def slow_async_add(to, x, y, device="cpu"): 506 return rpc.rpc_async(to, slow_add, args=(x, y, device)) 507 508 509@rpc.functions.async_execution 510def async_add_with_future_ctor(to, x, y, z): 511 fut = torch.futures.Future() 512 rpc.rpc_async(to, torch.add, args=(x, y)).then( 513 lambda fut1: fut.set_result(fut1.wait() + z) 514 ) 515 return fut 516 517 518@rpc.functions.async_execution 519def async_add_chained(to, x, y, z): 520 return rpc.rpc_async(to, torch.add, args=(x, y)).then( 521 lambda fut: fut.wait() + z 522 ) 523 524 525@rpc.functions.async_execution 526def async_add_chained_multi(to, x, num, step): 527 fut = rpc.rpc_async(to, torch.add, args=(x, 0)) 528 for _ in range(num): 529 fut = fut.then(lambda fut: fut.wait() + step) 530 return fut 531 532 533@rpc.functions.async_execution 534def async_add_nested(to, x, y, z): 535 return rpc.rpc_async(to, async_add, args=(to, x, y)).then( 536 lambda fut: fut.wait() + z 537 ) 538 539 540@rpc.functions.async_execution 541def async_add_multi_fanout(to, x, num, step): 542 futs = [] 543 for i in range(num): 544 if i == 0: 545 futs.append(rpc.rpc_async(to, torch.add, 
args=(x, step))) 546 else: 547 futs.append(rpc.rpc_async(to, torch.add, args=(0, step))) 548 549 # TODO: use torch.futures.collect_all 550 lock = Lock() 551 state = {"cnt": 0, "ret": torch.zeros_like(x)} 552 ret_future = torch.futures.Future() 553 554 def inc_and_set(fut): 555 with lock: 556 state["cnt"] += 1 557 state["ret"] += fut.wait() 558 if state["cnt"] >= len(futs): 559 ret_future.set_result(state["ret"]) 560 561 for fut in futs: 562 fut.then(inc_and_set) 563 564 return ret_future 565 566 567@rpc.functions.async_execution 568def async_cuda_sleep_and_set_to_one(t): 569 device = t.device 570 original_stream = torch.cuda.current_stream(device) 571 new_stream = torch.cuda.Stream(device) 572 new_stream.wait_stream(original_stream) 573 with torch.cuda.stream(new_stream): 574 torch.cuda._sleep(int(1000 * get_cycles_per_ms())) 575 t.fill_(1) 576 fut = Future(devices=[device]) 577 fut.set_result(t) 578 return fut 579 580 581@rpc.functions.async_execution 582def async_cuda_nested_add(to, x, y, z): 583 def cb(fut): 584 torch.cuda._sleep(int(1000 * get_cycles_per_ms())) 585 return fut.value() + z 586 587 return rpc.rpc_async(to, torch.add, args=(x, y)).then(cb) 588 589 590# A custom Python class that contains a tensor, needed to see if we correctly 591# use the Python pickler to extract tensors from non-IValue-convertible types. 592class TensorWrapper: 593 __slots__ = ("tensor", "lock", "event", "thread") 594 595 def __init__(self, t): 596 self.tensor = t 597 # Add one non-picklable field, to ensure it's ignored/skipped. 598 self.lock = Lock() 599 self.event = torch.cuda.Event(enable_timing=True) 600 self.thread = threading.Thread() 601 self.thread.start() 602 603 def increase(self, v): 604 with self.lock: 605 self.tensor += v 606 607 def sum(self): 608 with self.lock: 609 self.event.record() 610 return self.tensor.sum() 611 612 613class AsyncExecutionClass: 614 615 @staticmethod 616 @rpc.functions.async_execution 617 def static_async_add(to, x, y, z): 618 return rpc.rpc_async(to, torch.add, args=(x, y)).then( 619 lambda fut: fut.wait() + z 620 ) 621 622 @classmethod 623 @rpc.functions.async_execution 624 def class_async_add(cls, to, x, y, z): 625 ret_fut = torch.futures.Future() 626 rpc.rpc_async(to, torch.add, args=(x, y)).then( 627 lambda fut: ret_fut.set_result(fut.wait() + z) 628 ) 629 return ret_fut 630 631 @rpc.functions.async_execution 632 def bound_async_add(self, to, x, y, z): 633 return rpc.rpc_async(to, torch.add, args=(x, y)).then( 634 lambda fut: fut.wait() + z 635 ) 636 637 638def return_future(): 639 return torch.futures.Future() 640 641 642class FooBackendOptions(rpc.RpcBackendOptions): 643 def __init__(self, init_method): 644 # Must call the __init__ of the superclass (and do so directly, 645 # without using super()) because... pybind. 646 rpc.RpcBackendOptions.__init__(self) 647 self.init_method = init_method 648 649 650# load_tests from common_utils is used to automatically filter tests for 651# sharding on sandcastle. 
# This line silences flake warnings
load_tests = load_tests


class MyEmbeddingBagModel(torch.nn.Module):
    def __init__(self, sparse):
        super().__init__()
        self.eb = torch.nn.EmbeddingBag(
            10,
            10,
            sparse=sparse
        )

    def forward(self, x):
        return self.eb(x)


class MyParameterServer:
    def __init__(self, trainers):
        self.lock = Lock()
        self.trainers = trainers
        self.iteration = 0
        self.updates = 0
        self.futures = []
        self.total = None
        self.gradient = None

    @staticmethod
    def get_gradient(rref):
        return rref.local_value().gradient

    @staticmethod
    @rpc.functions.async_execution
    def average(rref, riteration, tensor):
        self = rref.local_value()
        fut = torch.futures.Future()
        with self.lock:
            if riteration > self.iteration:
                self.iteration = riteration
                self.updates = 0
                self.futures.clear()
            self.futures.append(fut)
            if self.total is None:
                self.total = tensor
            else:
                self.total += tensor
            self.updates += 1
            if self.trainers == self.updates:
                self.gradient = self.total / float(self.trainers)
                for fut in self.futures:
                    result = self.total / float(self.trainers)
                    fut.set_result(result)
        return fut


class MyConvNetForMNIST(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(1),
            nn.Linear(4608, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ).to(device)
        self.device = device

    def forward(self, x, is_rref=False):
        x = x.to_here() if is_rref else x
        with torch.cuda.stream(torch.cuda.current_stream(self.device)):
            # intentionally adding delay to current CUDA stream
            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
            return self.net(x)

    def __getstate__(self):
        # return an empty dict to avoid inspecting the model contents on the
        # owner
        return {}


class RpcTestCommon:
    def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None):
        if mode == RPCExecMode.SYNC:
            return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs)
        elif mode == RPCExecMode.ASYNC:
            return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait()
        elif mode == RPCExecMode.REMOTE:
            return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here()

    def _self_py_udf_remote(self, worker_info, x, y, z):
        rref = rpc.remote(worker_info, my_function, args=(x, y, z))
        self.assertEqual(rref.to_here(), x + y + z)

    def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
        self_worker_info = rpc.get_worker_info()
        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x))
        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y))
        self.assertEqual(ret, x + y + z + x + y)
        self.assertEqual(fut.wait(), x + y + z + x)

    def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
        self_worker_info = rpc.get_worker_info()
        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x))
        self.assertEqual(
            ret_rref.to_here(), x + y + z + x
        )

    def _world_size_one(self, a, b):
        if self.rank == 0:
            rpc.init_rpc(
                name="me",
                backend=self.rpc_backend,
                rank=0,
                world_size=1,
                rpc_backend_options=self.rpc_backend_options,
            )

            def _rpc_sync(x, y):
                expect = x * 2
                result = rpc.rpc_sync(
                    "me",
                    my_tensor_function,
                    args=(x, y)
                )
                self.assertEqual(expect, result)

            def _rpc_async(x, y):
                expect = x * 2
                result = rpc.rpc_async(
                    "me",
                    my_tensor_function,
                    args=(x, y)
                ).wait()
                self.assertEqual(expect, result)

            def _remote(x, y):
                expect = x * 2
                result = rpc.remote(
                    "me",
                    my_tensor_function,
                    args=(x, y)
                ).to_here()
                self.assertEqual(expect, result)

            _rpc_sync(a, b)
            _rpc_async(a, b)
            _remote(a, b)

            rpc.shutdown()

    def _multi_rpc(self, sparse):
        dst_rank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            if sparse:
                x = build_sparse_tensor() * n
                y = build_sparse_tensor() * n
            else:
                x = torch.ones(2, 2)
                y = torch.ones(2, 2)
            ret = rpc.rpc_sync(
                worker_name(dst_rank),
                torch.add,
                args=(x, y),
            )
            self.assertEqual(ret, x * 2)

    def _run_uneven_workload(self, f, x, num_repeat=30):
        # worker0 drives and waits for worker1 and worker2
        # throughout the test.
        if self.rank == 0:
            self.assertTrue(self.world_size >= 3)

            # Phase 1: Only worker1 has workload.
            dst = "worker1"
            futs = []
            for _ in range(num_repeat):
                fut = rpc.rpc_async(dst, f, args=(x,))
                futs.append(fut)

            for fut in torch.futures.collect_all(futs).wait():
                self.assertEqual(fut.wait(), 0)

            # Phase 2: Only worker2 has workload.
            # If join is not correctly implemented,
            # worker2 should be closed by now.
            dst = "worker2"
            futs = []
            for _ in range(num_repeat):
                fut = rpc.rpc_async(dst, f, args=(x,))
                futs.append(fut)

            for val in torch.futures.wait_all(futs):
                self.assertEqual(val, 0)

    def _wait_all_workers(self, f, x):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        self._run_uneven_workload(f, x)

        # worker0 calls this at the end after waiting for RPC responses.
        # worker1/2 calls this immediately and has some work after it.
        # worker3 calls this immediately and has no more work.
        rpc.api._wait_all_workers()

        # Wait before proceeding to shutdown to ensure worker0 RPCs make
        # it through to other workers.
        dist.barrier()
        rpc.shutdown(graceful=False)

    def _wait_all_workers_twice(self, f, x):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        self._run_uneven_workload(f, x)

        # worker0 calls this at the end after waiting for RPC responses.
        # worker1/2 calls this immediately and has some work after it.
        # worker3 calls this immediately and has no more work.
        rpc.api._wait_all_workers()
        rpc.api._wait_all_workers()

        # Wait before proceeding to shutdown to ensure worker0 RPCs make
        # it through to other workers.
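        # shutdown(graceful=False) below skips the agent's own join/drain step;
        # that is presumably safe here because the explicit _wait_all_workers()
        # calls above already performed the draining handshake, and a graceful
        # shutdown would effectively repeat it.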
894 dist.barrier() 895 rpc.shutdown(graceful=False) 896 897 def _nested_rpc(self, f, expected): 898 n = self.rank + 1 899 dst_rank = n % self.world_size 900 ret = rpc.rpc_sync( 901 worker_name(dst_rank), 902 f, 903 args=(worker_name(self.rank),), 904 ) 905 self.assertEqual(ret, expected) 906 907 def _stress_test_rpc(self, f, repeat=1000, args=()): 908 n = self.rank + 1 909 dst_rank = n % self.world_size 910 futs = [] 911 tik = time.time() 912 for _ in range(repeat): 913 fut = rpc.rpc_async(worker_name(dst_rank), f, args=args) 914 futs.append(fut) 915 916 for val in torch.futures.wait_all(futs): 917 self.assertEqual(val, 0) 918 tok = time.time() 919 print( 920 f"Rank {self.rank} finished testing {repeat} times in {tok - tik} seconds." 921 ) 922 923 def _builtin_remote_ret(self, x, y, expected): 924 n = self.rank + 1 925 dst_rank = n % self.world_size 926 rref = rpc.remote( 927 worker_name(dst_rank), 928 torch.add, 929 args=(x, y), 930 ) 931 self.assertEqual(rref.to_here(), expected) 932 933 def _builtin_remote_self(self, x, y, expected): 934 rref = rpc.remote( 935 worker_name(self.rank), 936 torch.add, 937 args=(x, y), 938 ) 939 self.assertEqual(rref.local_value(), expected) 940 941 def _test_multi_remote_call(self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}): 942 m = 10 943 n = self.rank + 1 944 dst_rank = n % self.world_size 945 rrefs = [] 946 expected = [] 947 for i in range(m): 948 n = n + i 949 rrefs.append( 950 rpc.remote( 951 worker_name(dst_rank), 952 fn, 953 args=args_fn(n, sparse), 954 kwargs=kwargs_fn(n, sparse), 955 ) 956 ) 957 expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse))) 958 959 for i in range(m): 960 self.assertEqual(rrefs[i].to_here(), expected[i]) 961 962 def _py_rref_args(self, a, b, x, y, expected): 963 n = self.rank + 1 964 dst_rank = n % self.world_size 965 rref_a = rpc.remote( 966 worker_name(dst_rank), torch.add, args=(a, b) 967 ) 968 rref_b = rpc.remote( 969 worker_name(dst_rank), torch.add, args=(x, y) 970 ) 971 rref_c = rpc.remote( 972 worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) 973 ) 974 self.assertEqual(rref_c.to_here(), expected) 975 976 def _py_rref_args_user_share(self, a, b, c, x, y, z, expected): 977 n = self.rank + 1 978 owner_rank = n % self.world_size 979 user_rank = (n + 1) % self.world_size 980 rref_a = rpc.remote( 981 worker_name(owner_rank), my_function, args=(a, b, c) 982 ) 983 rref_b = rpc.remote( 984 worker_name(owner_rank), my_function, args=(x, y, z) 985 ) 986 rref_c = rpc.remote( 987 worker_name(user_rank), my_rref_function, args=(rref_a, rref_b) 988 ) 989 self.assertEqual(rref_c.to_here(), expected) 990 991 def _py_rpc_rref_args(self, a, b, c, x, y, z, expected): 992 n = self.rank + 1 993 dst_rank = n % self.world_size 994 rref_a = rpc.remote( 995 worker_name(dst_rank), my_function, args=(a, b, c) 996 ) 997 rref_b = rpc.remote( 998 worker_name(dst_rank), my_function, args=(x, y, z) 999 ) 1000 1001 c = rpc.rpc_sync( 1002 worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) 1003 ) 1004 self.assertEqual(c, expected) 1005 1006 def _nested_remote(self, f, expected): 1007 n = self.rank + 1 1008 dst_rank1 = n % self.world_size 1009 dst_rank2 = (n + 1) % self.world_size 1010 1011 rref = rpc.remote( 1012 worker_name(dst_rank1), 1013 f, 1014 args=(worker_name(dst_rank2),), 1015 ) 1016 self.assertEqual(rref.to_here(), expected) 1017 1018 def _nested_rref(self, f, expected1, expected2): 1019 n = self.rank + 1 1020 dst_rank1 = n % self.world_size 1021 dst_rank2 = (n + 1) % self.world_size 
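        # Worked example: with world_size == 3 and self.rank == 0, n == 1, so
        # dst_rank1 == 1 (call it B) and dst_rank2 == 2 (call it C); this worker
        # (A) asks B to run f, and B creates RRefs whose values are owned by C.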
1022 rref_of_rrefs = rpc.remote( 1023 worker_name(dst_rank1), 1024 f, 1025 args=(worker_name(dst_rank2),), 1026 ) 1027 1028 # Say C has 2 OwnerRRefs. 1029 # B has 2 UserRRefs to those 2 OwnerRRefs, respectively. 1030 # This call is effectively A asking B to share its 2 UserRRefs. 1031 rrefs = rref_of_rrefs.to_here() 1032 1033 self.assertEqual(len(rrefs), 2) 1034 self.assertEqual(rrefs[0].to_here(), expected1) 1035 self.assertEqual(rrefs[1].to_here(), expected2) 1036 1037 def _nested_rref_stress(self, f, expected1, expected2): 1038 n = self.rank + 1 1039 dst_rank1 = n % self.world_size 1040 dst_rank2 = (n + 1) % self.world_size 1041 all_rrefs = [] 1042 for _ in range(20): 1043 all_rrefs.append( 1044 rpc.remote( 1045 worker_name(dst_rank1), 1046 f, 1047 args=(worker_name(dst_rank2),), 1048 ) 1049 ) 1050 1051 for i in range(20): 1052 rref_of_rrefs = all_rrefs[i] 1053 rrefs = rref_of_rrefs.to_here() 1054 self.assertEqual(len(rrefs), 2) 1055 self.assertEqual(rrefs[0].to_here(), expected1) 1056 self.assertEqual(rrefs[1].to_here(), expected2) 1057 1058 def _trainer_func(self, rref, sparse): 1059 m = MyEmbeddingBagModel(sparse=sparse) 1060 loss_fn = nn.MSELoss() 1061 for i in range(10): 1062 outputs = m(torch.rand(10, 10).long()) 1063 loss_fn(outputs, torch.rand(10, 10)).backward() 1064 gradient = next(iter(m.parameters())).grad 1065 fut = rref.rpc_async().average(rref, i, gradient) 1066 gradient = fut.wait() 1067 if gradient.is_sparse: 1068 gradient = gradient.to_dense().double() 1069 ps_gradient = rref.rpc_sync().get_gradient(rref) 1070 if ps_gradient.is_sparse: 1071 ps_gradient = ps_gradient.to_dense().double() 1072 self.assertTrue(torch.equal(gradient, ps_gradient)) 1073 1074 def _my_parameter_server(self, sparse): 1075 ps_rref = RRef(MyParameterServer(self.world_size - 1)) 1076 futures = [] 1077 for index in range(1, self.world_size): 1078 futures.append( 1079 rpc.rpc_async( 1080 worker_name((self.rank + index) % self.world_size), 1081 self._trainer_func, 1082 args=( 1083 ps_rref, 1084 sparse 1085 ), 1086 ) 1087 ) 1088 torch.futures.wait_all(futures) 1089 1090 def _test_cuda_future_extraction(self, wrapper, unwrapper, sparse_tensor): 1091 # We check proper CUDA stream synchronization by adding to the tensor 1092 # in one stream to get the expected value, and reading it from another stream. 
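        # The write below happens on one stream after an artificial GPU sleep,
        # while the read happens on a second stream. If the CUDA-aware Future did
        # not record and wait on the proper events when the result is set and
        # consumed, the consumer stream could observe stale values, which the
        # equality checks at the end would catch.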
1093 future = Future(devices=["cuda:0"]) 1094 with torch.cuda.device("cuda:0"): 1095 stream = torch.cuda.Stream() 1096 another_stream = torch.cuda.Stream() 1097 with torch.cuda.stream(stream): 1098 if sparse_tensor: 1099 tensor = build_sparse_tensor().to("cuda:0") 1100 add_tensor = build_sparse_tensor().to("cuda:0") 1101 expected_tensor = (tensor + add_tensor).coalesce() 1102 else: 1103 tensor = torch.zeros((100,), device="cuda:0") 1104 add_tensor = torch.ones((100,), device="cuda:0") 1105 expected_tensor = tensor + add_tensor 1106 torch.cuda._sleep(int(1000 * get_cycles_per_ms())) 1107 tensor += add_tensor 1108 if sparse_tensor: 1109 tensor = tensor.coalesce() 1110 future.set_result(wrapper(tensor)) 1111 with torch.cuda.stream(another_stream): 1112 tensor = unwrapper(future.wait()) 1113 if sparse_tensor: 1114 self.assertTrue(torch.eq(tensor.indices(), expected_tensor.indices()).all().item()) 1115 self.assertTrue(torch.eq(tensor.values(), expected_tensor.values()).all().item()) 1116 self.assertEqual(tensor.size(), expected_tensor.size()) 1117 else: 1118 self.assertTrue(torch.eq(tensor, expected_tensor).all().item()) 1119 1120 1121class RpcTest(RpcAgentTestFixture, RpcTestCommon): 1122 @dist_init 1123 def test_worker_id(self): 1124 n = self.rank + 1 1125 peer_rank = n % self.world_size 1126 self_worker_info = rpc.get_worker_info() 1127 peer_worker_info = rpc.get_worker_info(worker_name(peer_rank)) 1128 1129 self.assertEqual(self_worker_info.name, worker_name(self.rank)) 1130 self.assertEqual(peer_worker_info.name, worker_name(peer_rank)) 1131 1132 with self.assertRaisesRegex(RuntimeError, "could not find destination"): 1133 unknown_worker_id = rpc.get_worker_info("WorkerUnknown") 1134 1135 @dist_init 1136 def test_get_worker_infos(self): 1137 worker_infos = rpc.api._get_current_rpc_agent().get_worker_infos() 1138 1139 worker_names = {worker_info.name for worker_info in worker_infos} 1140 expected_worker_names = { 1141 worker_name(rank) for rank in range(self.world_size) 1142 } 1143 self.assertEqual(worker_names, expected_worker_names) 1144 1145 worker_ids = {worker_info.id for worker_info in worker_infos} 1146 expected_worker_ids = set(range(self.world_size)) 1147 self.assertEqual(worker_ids, expected_worker_ids) 1148 1149 @dist_init 1150 def test_self_add(self): 1151 self_worker_info = rpc.get_worker_info() 1152 self_worker_name = worker_name(self.rank) 1153 fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) 1154 ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) 1155 self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) 1156 self.assertEqual(ret, torch.ones(2, 2) + 1) 1157 1158 @dist_init 1159 def test_send_to_rank(self): 1160 dst_rank = (self.rank + 1) % self.world_size 1161 1162 # Test dense tensor 1163 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: 1164 ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) 1165 self.assertEqual(ret, torch.ones(2, 2) + 1) 1166 1167 # Test invalid ranks 1168 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: 1169 with self.assertRaises(RuntimeError): 1170 self._run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) 1171 1172 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: 1173 with self.assertRaises(RuntimeError): 1174 self._run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) 1175 1176 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, 
RPCExecMode.REMOTE]: 1177 with self.assertRaises(ValueError): 1178 self._run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) 1179 1180 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: 1181 with self.assertRaises(ValueError): 1182 self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) 1183 1184 @dist_init 1185 def test_self_py_udf_remote(self): 1186 self._self_py_udf_remote( 1187 rpc.get_worker_info(), 1188 torch.ones(2, 2), 1189 1, 1190 3 1191 ) 1192 1193 @dist_init 1194 def test_self_remote_rref_as_rpc_arg(self): 1195 dst = worker_name((self.rank + 1) % self.world_size) 1196 self._self_remote_rref_as_rpc_arg( 1197 dst, 1198 torch.ones(2, 2), 1199 1, 1200 3 1201 ) 1202 1203 @dist_init 1204 def test_self_remote_rref_as_self_rpc_arg(self): 1205 self._self_remote_rref_as_rpc_arg( 1206 rpc.get_worker_info(), 1207 torch.ones(2, 2), 1208 1, 1209 3 1210 ) 1211 1212 @dist_init 1213 def test_self_remote_rref_as_remote_arg(self): 1214 dst = worker_name((self.rank + 1) % self.world_size) 1215 self._self_remote_rref_as_remote_arg( 1216 dst, 1217 torch.ones(2, 2), 1218 1, 1219 3 1220 ) 1221 1222 @dist_init 1223 def test_self_remote_rref_as_self_remote_arg(self): 1224 self._self_remote_rref_as_remote_arg( 1225 rpc.get_worker_info(), 1226 torch.ones(2, 2), 1227 1, 1228 3 1229 ) 1230 1231 @dist_init 1232 def test_rref_proxy_non_exist(self): 1233 dst = worker_name((self.rank + 1) % self.world_size) 1234 rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3)) 1235 msg = "has no attribute 'non_exist'" 1236 with self.assertRaisesRegex(AttributeError, msg): 1237 rref.rpc_sync().non_exist() 1238 1239 with self.assertRaisesRegex(AttributeError, msg): 1240 rref.rpc_async().non_exist().wait() 1241 1242 with self.assertRaisesRegex(AttributeError, msg): 1243 rref.remote().non_exist() 1244 1245 def _test_rref_proxy_tensor(self, dst): 1246 rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3)) 1247 1248 expected = torch.ones(2, 2) + 1 + 3 1249 self.assertEqual(expected.size(), rref.rpc_sync().size()) 1250 self.assertEqual(expected + 1, rref.rpc_async().add(1).wait()) 1251 self.assertEqual(expected.view(1, 4), rref.remote().view(1, 4).to_here()) 1252 1253 @dist_init 1254 def test_rref_proxy_tensor(self): 1255 self._test_rref_proxy_tensor(worker_name((self.rank + 1) % self.world_size)) 1256 1257 @dist_init 1258 def test_rref_proxy_tensor_self(self): 1259 self._test_rref_proxy_tensor(rpc.get_worker_info()) 1260 1261 @dist_init 1262 def test_rref_proxy_reuse(self): 1263 rref = rpc.remote( 1264 worker_name((self.rank + 1) % self.world_size), 1265 my_function, 1266 args=(torch.ones(2, 2), 1, 3) 1267 ) 1268 expected = torch.ones(2, 2) + 1 + 3 1269 1270 proxy_rpc_sync = rref.rpc_sync() 1271 proxy_rpc_async = rref.rpc_async() 1272 proxy_remote = rref.remote() 1273 1274 self.assertEqual(expected.size(), proxy_rpc_sync.size()) 1275 self.assertEqual(expected + 1, proxy_rpc_sync.add(1)) 1276 self.assertEqual(expected.view(1, 4), proxy_rpc_sync.view(1, 4)) 1277 1278 self.assertEqual(expected.size(), proxy_rpc_async.size().wait()) 1279 self.assertEqual(expected + 3, proxy_rpc_async.add(3).wait()) 1280 self.assertEqual(expected.view(4, 1), proxy_rpc_async.view(4, 1).wait()) 1281 1282 self.assertEqual(expected.size(), proxy_remote.size().to_here()) 1283 self.assertEqual(expected + 5, proxy_remote.add(5).to_here()) 1284 self.assertEqual(expected.view(-1), proxy_remote.view(-1).to_here()) 1285 1286 def 
_test_rref_proxy_class(self, dst): 1287 rref = rpc.remote(dst, MyClass, args=(7,)) 1288 expected = MyClass(7) 1289 self.assertEqual(expected.get_value(), rref.rpc_sync().get_value()) 1290 self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait()) 1291 self.assertEqual(expected.get_value(), rref.remote().get_value().to_here()) 1292 1293 expected.increment_value(3) 1294 self.assertEqual(None, rref.rpc_sync().increment_value(1)) 1295 self.assertEqual(None, rref.rpc_async().increment_value(1).wait()) 1296 self.assertEqual(None, rref.remote().increment_value(1).to_here()) 1297 1298 self.assertEqual(expected.get_value(), rref.rpc_sync().get_value()) 1299 self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait()) 1300 self.assertEqual(expected.get_value(), rref.remote().get_value().to_here()) 1301 1302 self.assertEqual( 1303 expected.my_instance_method(2), 1304 rref.rpc_sync().my_instance_method(2) 1305 ) 1306 self.assertEqual( 1307 expected.my_instance_method(3), 1308 rref.rpc_async().my_instance_method(3).wait() 1309 ) 1310 self.assertEqual( 1311 expected.my_instance_method(4), 1312 rref.remote().my_instance_method(4).to_here() 1313 ) 1314 1315 self.assertEqual( 1316 expected.my_static_method(9), 1317 rref.rpc_sync().my_static_method(9) 1318 ) 1319 self.assertEqual( 1320 expected.my_static_method(10), 1321 rref.rpc_async().my_static_method(10).wait() 1322 ) 1323 self.assertEqual( 1324 expected.my_static_method(11), 1325 rref.remote().my_static_method(11).to_here() 1326 ) 1327 1328 self.assertEqual( 1329 expected.my_class_method(2, torch.zeros(2, 2)), 1330 rref.rpc_sync().my_class_method(2, torch.zeros(2, 2)) 1331 ) 1332 self.assertEqual( 1333 expected.my_class_method(2, torch.ones(3, 3)), 1334 rref.rpc_async().my_class_method(2, torch.ones(3, 3)).wait() 1335 ) 1336 self.assertEqual( 1337 expected.my_class_method(2, torch.ones(4, 4)), 1338 rref.remote().my_class_method(2, torch.ones(4, 4)).to_here() 1339 ) 1340 1341 @dist_init 1342 def test_rref_proxy_class(self): 1343 self._test_rref_proxy_class(worker_name((self.rank + 1) % self.world_size)) 1344 1345 @dist_init 1346 def test_rref_proxy_class_self(self): 1347 self._test_rref_proxy_class(rpc.get_worker_info()) 1348 1349 @mock.patch.object(torch.distributed.autograd, "_init") 1350 @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent") 1351 @dist_init(setup_rpc=False) 1352 def test_register_rpc_backend_and_set_and_start_rpc_backend( 1353 self, mock_rpc_agent, mock_dist_autograd_init 1354 ): 1355 backend_name = "stub_backend" 1356 1357 backend = rpc.backend_registry.register_backend( 1358 backend_name, 1359 _stub_construct_rpc_backend_options_handler, 1360 _stub_init_rpc_backend_handler, 1361 ) 1362 1363 with self.assertRaisesRegex( 1364 RuntimeError, "^RPC backend .+: already registered$" 1365 ): 1366 backend = rpc.backend_registry.register_backend( 1367 backend_name, 1368 _stub_construct_rpc_backend_options_handler, 1369 _stub_init_rpc_backend_handler, 1370 ) 1371 1372 rpc.init_rpc( 1373 name="worker1", 1374 backend=backend, 1375 rank=self.rank, 1376 world_size=self.world_size, 1377 rpc_backend_options=self.rpc_backend_options, 1378 ) 1379 1380 @dist_init(setup_rpc=False) 1381 def test_duplicate_name(self): 1382 with self.assertRaisesRegex(RuntimeError, "is not unique"): 1383 store, _, _ = next( 1384 torch.distributed.rendezvous( 1385 self.init_method, rank=self.rank, world_size=self.world_size 1386 ) 1387 ) 1388 rpc._init_rpc_backend( 1389 backend=self.rpc_backend, 1390 store=store, 1391 
name="duplicate_name", 1392 rank=self.rank, 1393 world_size=self.world_size, 1394 rpc_backend_options=self.rpc_backend_options, 1395 ) 1396 1397 @dist_init(setup_rpc=False) 1398 def test_duplicate_name_2(self): 1399 with self.assertRaisesRegex(RuntimeError, "is not unique"): 1400 rpc.init_rpc( 1401 name=worker_name(self.rank % (self.world_size - 1)), 1402 backend=self.rpc_backend, 1403 rank=self.rank, 1404 world_size=self.world_size, 1405 rpc_backend_options=self.rpc_backend_options, 1406 ) 1407 1408 @dist_init(setup_rpc=False) 1409 def test_reinit(self): 1410 rpc.init_rpc( 1411 name=worker_name(self.rank), 1412 backend=self.rpc_backend, 1413 rank=self.rank, 1414 world_size=self.world_size, 1415 rpc_backend_options=self.rpc_backend_options, 1416 ) 1417 1418 initialize_pg(self.file_init_method, self.rank, self.world_size) 1419 # Wait for all init to complete. 1420 dist.barrier() 1421 1422 # TODO: with TCP init, rank 0 raises Address already in use because 1423 # rank 0 is the start daemon and the store is created before checking if 1424 # RPC is already initialized in init_rpc. 1425 if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0: 1426 expected_reinit_err = "Address already in use" 1427 else: 1428 expected_reinit_err = "is already initialized" 1429 1430 with self.assertRaisesRegex(RuntimeError, expected_reinit_err): 1431 rpc.init_rpc( 1432 name=worker_name(self.rank), 1433 backend=self.rpc_backend, 1434 rank=self.rank, 1435 world_size=self.world_size, 1436 rpc_backend_options=self.rpc_backend_options, 1437 ) 1438 rpc.shutdown() 1439 1440 @dist_init(setup_rpc=False) 1441 def test_pg_init_no_rpc_init(self): 1442 dist.init_process_group( 1443 backend='gloo', 1444 init_method=self.file_init_method, 1445 rank=self.rank, 1446 world_size=self.world_size) 1447 1448 class MyModel(torch.nn.Module): 1449 def __init__(self) -> None: 1450 super().__init__() 1451 self.lin = torch.nn.Linear(3, 4) 1452 1453 def forward(self, x): 1454 return self.lin(x) 1455 1456 model = MyModel() 1457 model.train() 1458 model = torch.nn.parallel.DistributedDataParallel(model) 1459 1460 with self.assertRaisesRegex(RuntimeError, 'Current RPC agent is not set! Did you initialize the RPC framework'): 1461 params = [] 1462 for param in model.parameters(): 1463 params.append(RRef(param)) 1464 1465 def test_world_size_one(self): 1466 self._world_size_one( 1467 torch.ones(2, 2), 1468 torch.ones(2, 2) 1469 ) 1470 1471 @dist_init(setup_rpc=False) 1472 def test_invalid_names(self): 1473 1474 worker_id = 0 1475 with self.assertRaisesRegex(RuntimeError, "Worker name must match"): 1476 info = WorkerInfo("abc*", worker_id) 1477 1478 with self.assertRaisesRegex(RuntimeError, "Worker name must match"): 1479 info = WorkerInfo(" ", worker_id) 1480 1481 with self.assertRaisesRegex(RuntimeError, "must be non-empty"): 1482 info = WorkerInfo("", worker_id) 1483 1484 # If the number in the message does not match, it is likely that the 1485 # value of MAX_NAME_LEN in RPC WorkerInfo has changed. 
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            info = WorkerInfo("".join(["a" for i in range(500)]), worker_id)

    # Test that WorkerInfo can be pickled and sent in RPC call
    @dist_init
    def test_worker_info_pickle(self):
        dst_rank = (self.rank + 1) % self.world_size
        worker_info = rpc.api.get_worker_info()
        ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,))
        self.assertEqual(ret, worker_info)

    @dist_init
    def test_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @staticmethod
    def return_callee_id():
        return rpc.get_worker_info().id

    @dist_init
    def test_int_callee(self):
        dst_rank = (self.rank + 1) % self.world_size
        ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
        self.assertEqual(ret, dst_rank)

    @dist_init
    def test_add_with_id(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        worker_info = rpc.get_worker_info(worker_name(dst_rank))

        ret = rpc.rpc_sync(
            worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init
    def test_scalar_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), n)
        )
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @dist_init
    def test_async_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async(
            worker_name(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @dist_init
    def test_nonzero(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @dist_init
    def test_multi_rpc(self):
        self._multi_rpc(False)

    @dist_init
    def test_future_wait_twice(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        futs = []
        for i in range(20):
            futs.append(rpc.rpc_async(dst, raise_func))

        with self.assertRaisesRegex(ValueError, "Expected error"):
            torch.futures.wait_all(futs)

        for fut in futs:
            with self.assertRaisesRegex(ValueError, "Expected error"):
                fut.wait()

    @dist_init(setup_rpc=False)
    def test_wait_all_workers_timeout(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        og_func = rpc.api._wait_all_workers

        def wait_all_workers_sleep(timeout):
            rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)

        rpc.api._wait_all_workers = wait_all_workers_sleep

        try:
            with self.assertRaisesRegex(RuntimeError, ''):
                rpc.shutdown(graceful=True, timeout=0.01)
        finally:
            rpc.api._wait_all_workers = og_func
        dist.barrier()

    def test_wait_all_workers_dense(self):
self._wait_all_workers(heavy_rpc, torch.ones(100, 100)) 1604 1605 def test_wait_all_workers_twice_dense(self): 1606 self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100)) 1607 1608 @dist_init 1609 def test_all_gather(self): 1610 info = rpc.get_worker_info() 1611 results = rpc.api._all_gather(info.id) 1612 expected = {} 1613 for info in rpc._get_current_rpc_agent().get_worker_infos(): 1614 expected[info.name] = info.id 1615 1616 self.assertEqual(expected, results) 1617 1618 @dist_init 1619 def test_all_gather_timeout(self): 1620 rpc._set_rpc_timeout(0.1) 1621 1622 if self.rank == 0: 1623 with self.assertRaisesRegex( 1624 RuntimeError, 1625 "timed out in _all_gather after 0\\.10 seconds" 1626 ): 1627 rpc.api._all_gather(SlowPickleClass(0.5)) 1628 else: 1629 expected_error = self.get_timeout_error_regex() 1630 with self.assertRaisesRegex(RuntimeError, expected_error): 1631 rpc.api._all_gather(SlowPickleClass(0.5)) 1632 1633 def _test_barrier_helper(self, info, names, multi_threaded=False): 1634 names = sorted(names) 1635 leader = names[0] 1636 rpc.rpc_sync(leader, _reset_count) 1637 if not multi_threaded and info.name == leader: 1638 self.assertEqual(_rpc_barrier_count, 0) 1639 rpc.api._barrier(names) 1640 rpc.rpc_sync(leader, _increment_count) 1641 rpc.api._barrier(names) 1642 if not multi_threaded and info.name == leader: 1643 self.assertEqual(_rpc_barrier_count, len(names)) 1644 1645 @dist_init 1646 def test_rpc_barrier_all(self): 1647 # Test rpc barrier when called with full list of workers 1648 info = rpc.get_worker_info() 1649 all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() 1650 names = [worker.name for worker in all_worker_info] 1651 self._test_barrier_helper(info, names) 1652 1653 @dist_init 1654 def test_rpc_barrier_subset(self): 1655 # Test rpc barrier when processes are called with different subsets of the full list 1656 info = rpc.get_worker_info() 1657 all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() 1658 if info.id % 2: 1659 names = [worker.name for worker in all_worker_info if worker.id % 2] 1660 else: 1661 names = [worker.name for worker in all_worker_info if not worker.id % 2] 1662 self._test_barrier_helper(info, names) 1663 1664 @dist_init 1665 def test_rpc_barrier_partial_subset(self): 1666 # Test rpc barrier when some processes are not involved in the barrier 1667 info = rpc.get_worker_info() 1668 all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() 1669 if info.id % 2: 1670 names = [worker.name for worker in all_worker_info if worker.id % 2] 1671 else: 1672 names = [f"worker{info.id}"] 1673 self._test_barrier_helper(info, names) 1674 1675 @dist_init 1676 def test_rpc_barrier_multithreaded(self): 1677 # This tests validates the implementation of barrier when multiple threads call into it 1678 # We only need to check that it does not hang in this case 1679 info = rpc.get_worker_info() 1680 all_worker_info = rpc._get_current_rpc_agent().get_worker_infos() 1681 names = [worker.name for worker in all_worker_info] 1682 threads = [] 1683 for _ in range(3): 1684 th = threading.Thread(target=self._test_barrier_helper, args=(info, names, True)) 1685 threads.append(th) 1686 th.start() 1687 for th in threads: 1688 th.join() 1689 1690 @dist_init 1691 def test_graceful_shutdown_with_uneven_workload(self): 1692 """Test graceful termination.""" 1693 self._run_uneven_workload(heavy_rpc, torch.ones(100, 100)) 1694 1695 @dist_init(setup_rpc=False) 1696 def test_shutdown_followed_by_rpc(self): 1697 # Initialize RPC. 
1698 rpc.init_rpc( 1699 name="worker%d" % self.rank, 1700 backend=self.rpc_backend, 1701 rank=self.rank, 1702 world_size=self.world_size, 1703 rpc_backend_options=self.rpc_backend_options, 1704 ) 1705 1706 n = self.rank + 1 1707 dst_rank = n % self.world_size 1708 ret = rpc.rpc_sync( 1709 worker_name(dst_rank), 1710 torch.add, 1711 args=(torch.ones(n, n), torch.ones(n, n)), 1712 ) 1713 self.assertEqual(ret, torch.ones(n, n) * 2) 1714 rpc.shutdown() 1715 1716 with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"): 1717 rpc.rpc_sync( 1718 worker_name(dst_rank), 1719 torch.add, 1720 args=(torch.ones(n, n), torch.ones(n, n)), 1721 ) 1722 1723 @dist_init 1724 def test_expected_src(self): 1725 dst_rank = (self.rank + 1) % self.world_size 1726 expected_src_rank = (self.rank - 1) % self.world_size 1727 ret = rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,)) 1728 value = VALUE_FUTURE.result() 1729 self.assertEqual(value, expected_src_rank) 1730 1731 @dist_init 1732 def test_py_built_in(self): 1733 n = self.rank + 1 1734 dst_rank = n % self.world_size 1735 ret = rpc.rpc_sync(worker_name(dst_rank), min, args=(n, n + 1, n + 2)) 1736 self.assertEqual(ret, min(n, n + 1, n + 2)) 1737 1738 @dist_init 1739 def test_py_user_defined(self): 1740 n = self.rank + 1 1741 dst_rank = n % self.world_size 1742 ret = rpc.rpc_sync( 1743 worker_name(dst_rank), 1744 my_function, 1745 kwargs={"a": n, "b": n + 1, "c": n + 2}, 1746 ) 1747 self.assertEqual(ret, my_function(n, n + 1, n + 2)) 1748 1749 def test_build_rpc_profiling_key(self): 1750 # Tests that the name that shows up as an Event in profiling RPCs has all 1751 # the necessary information. 1752 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: 1753 rpc_profiling_key = _build_rpc_profiling_key( 1754 exec_mode, "foo", "worker0", "worker1" 1755 ) 1756 self.assertIn(exec_mode.value, rpc_profiling_key) 1757 self.assertIn("foo", rpc_profiling_key) 1758 self.assertIn("worker0", rpc_profiling_key) 1759 self.assertIn("worker1", rpc_profiling_key) 1760 1761 def check_profiling_info(self, self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode): 1762 self.assertTrue(self_worker_name in rpc_event.name) 1763 self.assertTrue(dst_worker_name in rpc_event.name) 1764 if isinstance(func, torch.jit.ScriptFunction): 1765 self.assertTrue(torch._jit_internal._qualified_name(func) in rpc_event.name) 1766 else: 1767 self.assertTrue(func.__name__ in rpc_event.name) 1768 self.assertTrue(rpc_exec_mode.value in rpc_event.name) 1769 self.assertEqual(rpc_event.count, 1) 1770 1771 @dist_init 1772 def test_profiler_rpc_record_shapes(self): 1773 if self.rank != 1: 1774 return 1775 dst = (self.rank + 1) % self.world_size 1776 dst_worker = worker_name(dst) 1777 t1, t2 = torch.ones(100), torch.ones(100) 1778 with _profile(record_shapes=True) as prof: 1779 rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2)) 1780 1781 function_events = prof.function_events 1782 remote_events = [event for event in function_events if event.is_remote] 1783 remote_add_event = next( 1784 event for event in remote_events if "aten::add" in event.name 1785 ) 1786 remote_add_input_shapes = remote_add_event.input_shapes 1787 # Run profiler on equivalent local op and validate shapes are the same. 
1788 with _profile(record_shapes=True) as prof: 1789 torch.add(t1, t2) 1790 1791 local_function_events = prof.function_events 1792 local_add_event = next( 1793 event for event in local_function_events if "aten::add" in event.name 1794 ) 1795 local_add_input_shapes = local_add_event.input_shapes 1796 self.assertEqual(remote_add_input_shapes, local_add_input_shapes) 1797 1798 @dist_init 1799 def test_profiler_rpc_memory(self): 1800 if self.rank != 1: 1801 return 1802 dst = (self.rank + 1) % self.world_size 1803 dst_worker = worker_name(dst) 1804 with _profile(profile_memory=True) as p: 1805 fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) 1806 res = fut.wait() 1807 1808 function_events = p.function_events 1809 event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} 1810 # if cpu_memory_usage was not propagated over the wire, this set would 1811 # only contain 0 (indicates no memory being profiled) 1812 self.assertNotEqual({0}, event_cpu_mem_usages) 1813 # No memory profiled if profile_memory=False 1814 with _profile(profile_memory=False) as p: 1815 fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) 1816 res = fut.wait() 1817 1818 function_events = p.function_events 1819 event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} 1820 self.assertEqual({0}, event_cpu_mem_usages) 1821 1822 @dist_init 1823 def test_profiler_export_trace(self): 1824 if self.rank != 1: 1825 return 1826 dst = (self.rank + 1) % self.world_size 1827 dst_worker = worker_name(dst) 1828 with _profile() as p: 1829 fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) 1830 res = fut.wait() 1831 1832 events = p.function_events 1833 with TemporaryFileName() as fname: 1834 path = fname 1835 p.export_chrome_trace(path) 1836 with open(path) as f: 1837 trace = json.load(f) 1838 event_names = [event['name'] for event in trace] 1839 for expected_event_name in EXPECTED_REMOTE_EVENTS + [RPCExecMode.ASYNC.value]: 1840 event_exists = any(expected_event_name in event_name for event_name in event_names) 1841 self.assertTrue(event_exists) 1842 1843 @dist_init 1844 def test_profiler_rpc_key_names(self): 1845 # tests that remote events are properly prefixed with the RPC profiling key. 1846 if self.rank != 1: 1847 return 1848 1849 # Spawn multiple threads that send RPCs to ensure keys are correctly 1850 # prefixed when there are multiple RPCs being created/in flight at the 1851 # same time. 1852 dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank] 1853 1854 def rpc_with_profiling(dst_worker): 1855 with _profile() as prof: 1856 fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) 1857 fut.wait() 1858 1859 events = prof.function_events 1860 remote_event_names = { 1861 event.name: event for event in events if event.is_remote 1862 } 1863 rpc_profiling_key = _build_rpc_profiling_key( 1864 RPCExecMode.ASYNC, 1865 udf_with_torch_ops.__qualname__, 1866 worker_name(self.rank), 1867 dst_worker, 1868 ) 1869 1870 remote_event_name_set = set(EXPECTED_REMOTE_EVENTS) 1871 for name, event in remote_event_names.items(): 1872 # Ensure that we have the expected key as part of the remote 1873 # event. 1874 self.assertTrue(name.startswith(rpc_profiling_key)) 1875 self.assertTrue(event.is_remote) 1876 self.assertTrue(event.node_id == rpc.get_worker_info(dst_worker).id) 1877 # Ensure that the remote event name also contains the operator. 
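                # A remote event name has roughly the form
                # "<rpc_profiling_key>#remote_op: <operator>" (see REMOTE_OP_STR),
                # so stripping the key prefix below leaves the operator portion
                # that is matched against EXPECTED_REMOTE_EVENTS.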
1878 operator_name_substr = name[len(rpc_profiling_key) :] 1879 # Note: we don't assert that every remote event needs to be 1880 # in the above set, the set is just a representative set of 1881 # what we expect to see. The profiler can change and add more 1882 # events, but we should always expect to see this representative 1883 # set. 1884 matching_event = { 1885 remote_event_name 1886 for remote_event_name in remote_event_name_set 1887 if remote_event_name in operator_name_substr 1888 } 1889 remote_event_name_set -= matching_event 1890 1891 # The set should be empty, otherwise its contained elements did 1892 # not show up in the remote profiler output. 1893 self.assertTrue( 1894 remote_event_name_set == set(), 1895 f"Expected {remote_event_name_set} to be included in remote profiler output.", 1896 ) 1897 1898 for dst in dst_ranks: 1899 dst_worker = worker_name(dst) 1900 num_parallel_rpcs = 2 1901 with concurrent.futures.ThreadPoolExecutor( 1902 max_workers=num_parallel_rpcs 1903 ) as executor: 1904 futs = [ 1905 executor.submit(rpc_with_profiling, dst_worker) 1906 for _ in range(num_parallel_rpcs) 1907 ] 1908 # Wait for workers to finish test 1909 for fut in futs: 1910 fut.result() 1911 1912 def _run_test_profiler_remote_events_profiled(self): 1913 # Tests that we can successfully invoke the profiler on a remote node, 1914 # and collect the remote events back in the local profiler. 1915 if self.rank != 1: 1916 return 1917 1918 dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank] 1919 for dst in dst_ranks: 1920 dst_worker = worker_name(dst) 1921 with _profile() as prof: 1922 fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) 1923 ret = fut.wait() 1924 1925 events = prof.function_events 1926 1927 rpc_event = get_function_event(events, RPCExecMode.ASYNC.value) 1928 self.check_profiling_info( 1929 worker_name(self.rank), 1930 dst_worker, 1931 udf_with_torch_ops, 1932 rpc_event, 1933 RPCExecMode.ASYNC, 1934 ) 1935 1936 remote_events = {event.name: event for event in events if event.is_remote} 1937 rpc_profiling_key = _build_rpc_profiling_key( 1938 RPCExecMode.ASYNC, 1939 udf_with_torch_ops.__qualname__, 1940 worker_name(self.rank), 1941 worker_name(dst), 1942 ) 1943 1944 for expected_remote_event_name in EXPECTED_REMOTE_EVENTS: 1945 expected_key = rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name 1946 self.assertTrue(expected_key in remote_events) 1947 remote_event = remote_events[expected_key] 1948 # Remote event should have a node ID corresponding to the worker 1949 # it ran on. 1950 self.assertEqual(remote_event.node_id, dst) 1951 1952 # Validate order remote events show up in profiling output. 
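            # convert_remote_to_local() below strips everything up to and
            # including the "#remote_op: " marker, leaving just the operator name
            # so it can be compared against EXPECTED_REMOTE_EVENTS directly.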
1953 def convert_remote_to_local(event_name): 1954 remote_op_key = rpc_profiling_key + REMOTE_OP_STR 1955 return event_name[ 1956 event_name.find(remote_op_key) 1957 + len(remote_op_key) : 1958 ] 1959 1960 remote_events_list = [ 1961 convert_remote_to_local(event.name) 1962 for event in events 1963 if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS 1964 ] 1965 self.assertEqual( 1966 set(remote_events_list), 1967 set(EXPECTED_REMOTE_EVENTS), 1968 f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}", 1969 ) 1970 1971 @dist_init 1972 def test_profiler_remote_events_profiled(self): 1973 self._run_test_profiler_remote_events_profiled() 1974 1975 @dist_init 1976 def test_profiler_remote_events_profiled_single_threaded(self): 1977 self._run_test_profiler_remote_events_profiled() 1978 1979 def run_profiling_workload(self, dst): 1980 fut = rpc.rpc_async( 1981 worker_name(dst), 1982 torch.mul, 1983 args=( 1984 torch.tensor(1.0, requires_grad=True), 1985 torch.tensor(1.0, requires_grad=True), 1986 ), 1987 ) 1988 fut.wait() 1989 1990 def _run_rpc_profiling_async_function(self, device="cpu"): 1991 if self.rank != 1: 1992 return 1993 1994 dst1 = worker_name((self.rank + 1) % self.world_size) 1995 dst2 = worker_name((self.rank + 2) % self.world_size) 1996 x = torch.ones(2) 1997 y = torch.ones(2) 1998 with _profile() as prof: 1999 ret = rpc.rpc_async( 2000 dst1, slow_async_add, args=(dst2, x, y, device), timeout=20 2001 ) 2002 out = ret.wait() 2003 2004 function_events = prof.function_events 2005 # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be 2006 # recorded. 2007 key_prefix = _build_rpc_profiling_key( 2008 RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1 2009 ) 2010 2011 nested_rpc_key_prefix = _build_rpc_profiling_key( 2012 RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2 2013 ) 2014 expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix 2015 remote_events = [event for event in function_events if event.is_remote] 2016 rpc_remote_event = [ 2017 event for event in remote_events if event.name == expected_key 2018 ] 2019 self.assertEqual(1, len(rpc_remote_event)) 2020 rpc_remote_event = rpc_remote_event[0] 2021 self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size) 2022 # slow_async_add's RPC does an add on dst2, which should be reflected as well. 2023 remote_add_key = ( 2024 expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add) 2025 ) 2026 remote_add_event = [ 2027 event for event in remote_events if event.name == remote_add_key 2028 ] 2029 self.assertEqual(1, len(remote_add_event)) 2030 remote_add_event = remote_add_event[0] 2031 # Validate that node_id is dst2. 
        self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size)

    @dist_init
    def test_rpc_profiling_async_function(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        self._run_rpc_profiling_async_function()
        if torch.cuda.is_available():
            dist.barrier()
            self._run_rpc_profiling_async_function(device="cuda:0")

    @dist_init
    def test_rpc_profiling_async_function_single_threaded(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        self._run_rpc_profiling_async_function()
        if torch.cuda.is_available():
            dist.barrier()
            self._run_rpc_profiling_async_function(device="cuda:0")

    @dist_init
    def test_rpc_profiling_remote_record_function(self):
        # Test that functions run over RPC with record_function show the expected
        # profiled block.
        if self.rank != 1:
            return
        dst_ranks = [i for i in range(self.world_size) if i != self.rank]
        for dst_rank in dst_ranks:
            dst_worker = worker_name(dst_rank)
            with _profile() as prof:
                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=(-1, True))
                fut.wait()

            function_events = prof.function_events
            record_function_remote_event = [
                evt for evt in function_events if "##forward##" in evt.name
            ]
            self.assertEqual(1, len(record_function_remote_event))
            record_function_remote_event = record_function_remote_event[0]
            self.assertEqual(record_function_remote_event.node_id, dst_rank)
            # cpu_children only returns direct children, so here we get all
            # children recursively.

            def get_cpu_children(event):
                if not event.cpu_children:
                    return []
                # Copy the list so the recursion does not mutate the event's own
                # cpu_children while iterating over it.
                cpu_children = list(event.cpu_children)
                for e in event.cpu_children:
                    cpu_children.extend(get_cpu_children(e))
                return cpu_children

            remote_children = get_cpu_children(record_function_remote_event)
            # Get local children and verify parity.
2083 with _profile() as prof: 2084 udf_with_torch_ops(-1, True) 2085 2086 local_function_events = prof.function_events 2087 local_record_function_event = next( 2088 evt for evt in local_function_events if "##forward##" in evt.name 2089 ) 2090 local_children = get_cpu_children(local_record_function_event) 2091 local_children_names = [ 2092 evt.name for evt in local_children 2093 ] 2094 2095 REMOTE_OP_STR = "#remote_op: " 2096 2097 def convert_remote_to_local(event_name): 2098 remote_op_key = REMOTE_OP_STR 2099 return event_name[ 2100 event_name.find(remote_op_key) + len(remote_op_key) : 2101 ] 2102 2103 for evt in remote_children: 2104 local_name = convert_remote_to_local(evt.name) 2105 self.assertTrue(local_name in local_children_names) 2106 2107 def validate_profiling_workload(self, dst, prof): 2108 2109 def convert_remote_to_local(event_name): 2110 return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :] 2111 2112 events = prof.function_events 2113 remote_events = { 2114 convert_remote_to_local(event.name): event 2115 for event in events 2116 if event.is_remote 2117 } 2118 self.assertTrue("aten::mul" in remote_events) 2119 remote_mul_event = remote_events["aten::mul"] 2120 self.assertEqual(remote_mul_event.node_id, dst) 2121 self.check_profiling_info( 2122 worker_name(self.rank), 2123 worker_name(dst), 2124 torch.mul, 2125 remote_mul_event, 2126 RPCExecMode.ASYNC, 2127 ) 2128 2129 def _run_test_profiler_with_autograd_context(self): 2130 dst = (self.rank + 1) % self.world_size 2131 if self.rank == 1: 2132 # Cases where we can double wrap messages with profiling information and autograd info. 2133 with dist_autograd.context() as context_id: 2134 with _profile() as prof: 2135 self.run_profiling_workload(dst) 2136 2137 self.validate_profiling_workload(dst, prof) 2138 2139 # Ensure that flipped order of ctx managers results in events being 2140 # recorded as expected. 2141 with _profile() as prof: 2142 with dist_autograd.context() as context_id: 2143 self.run_profiling_workload(dst) 2144 2145 self.validate_profiling_workload(dst, prof) 2146 2147 @dist_init 2148 def test_profiler_with_autograd_context_single_threaded(self): 2149 self._run_test_profiler_with_autograd_context() 2150 2151 @dist_init 2152 def test_profiler_with_autograd_context(self): 2153 self._run_test_profiler_with_autograd_context() 2154 2155 def _profiler_test_with_rpc( 2156 self, rpc_exec_mode, func, args, use_record_function=False, dst=None, kineto_profile=False 2157 ): 2158 dst = dst if dst is not None else (self.rank + 1) % self.world_size 2159 2160 # only run profiler on rank 1. 2161 p = _profile if not kineto_profile else torch.profiler.profile # kineto 2162 if self.rank == 1: 2163 with p() as prof: 2164 record_function_ctx_mgr = ( 2165 contextlib.nullcontext() 2166 if not use_record_function 2167 else torch.autograd.profiler.record_function( 2168 "foo" 2169 ) 2170 ) 2171 with record_function_ctx_mgr as rf: 2172 if rpc_exec_mode == RPCExecMode.SYNC: 2173 rpc.rpc_sync(worker_name(dst), func, args=args) 2174 elif rpc_exec_mode == RPCExecMode.ASYNC: 2175 fut = rpc.rpc_async(worker_name(dst), func, args=args) 2176 if kineto_profile: 2177 # Ensure multiple async RPCs don't cause issues. 2178 # Would have raised 2179 # "RuntimeError: Cannot call 2180 # RemoteProfilerManager::setCurrentKey when current 2181 # key is already set." error if RPC profiling was 2182 # not disabled properly for kineto. 
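                            # Illustrative note (not exercised by this test): only the
                            # legacy autograd profiler used elsewhere in this file
                            # (_profile) attaches RPC profiling keys; under the kineto
                            # profiler the RPC layer skips key registration, so e.g.
                            #
                            #   with torch.profiler.profile() as kineto_prof:
                            #       rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2))
                            #   # kineto_prof.events() has no "rpc_sync#..." entry
                            #
                            # where dst_worker, t1 and t2 are placeholders. The two
                            # concurrent rpc_async calls below check that the kineto
                            # path leaves the RPC profiling key machinery disabled.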
                            fut2 = rpc.rpc_async(worker_name(dst), func, args=args)
                            fut2.wait()
                        fut.wait()
                    else:
                        self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE)
                        rref = rpc.remote(worker_name(dst), func, args=args)
                        rref.to_here()
                        # To avoid flakiness, wait for the RRef to be profiled. This
                        # means that we received the acknowledgement of successful
                        # creation on the owner and ran the callbacks responsible
                        # for recording the profiling event.
                        rref._get_profiling_future().wait()

            events = prof.function_events if not kineto_profile else prof.events()
            if kineto_profile:
                # RPC profiling is disabled so there should be no rpc related
                # events.
                with self.assertRaises(IndexError):
                    get_function_event(events, rpc_exec_mode.value)

                return

            rpc_event = get_function_event(events, rpc_exec_mode.value)
            # verify Node ID for this rpc event.
            self.assertEqual(rpc_event.node_id, self.rank)
            # Ensure recording of remote events.
            remote_events = {event for event in events if event.node_id == dst} - {rpc_event}
            self.assertGreaterEqual(len(remote_events), 1)
            for remote_event in remote_events:
                self.assertEqual(remote_event.node_id, dst)

            if use_record_function:
                scope_event = get_function_event(events, "foo")
                # Since RPC call is within the scope, its CPU interval should be
                # contained within foo's interval.
                self.assertLessEqual(scope_event.time_range.start, rpc_event.time_range.start)
                self.assertGreaterEqual(scope_event.time_range.end, rpc_event.time_range.end)
            # the sender, dest worker, function run, and type of RPC should all
            # be recorded.
            self_worker_name = worker_name(self.rank)
            dst_worker_name = worker_name(dst)
            self.check_profiling_info(self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode)
            if use_record_function:
                # verify order by ensuring that the outer context comes
                # before the rpc event.
2228 foo_event_ix = next(i for i, event in enumerate(events) if "foo" in event.name) 2229 rpc_event_idx = next(i for i, event in enumerate(events) if rpc_exec_mode.value in event.name) 2230 self.assertLess(foo_event_ix, rpc_event_idx) 2231 2232 def _run_test_profiler_with_sync_rpc_udf(self): 2233 self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,)) 2234 self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,), 2235 use_record_function=True) 2236 2237 @dist_init 2238 def test_profiler_with_sync_rpc_udf(self): 2239 self._run_test_profiler_with_sync_rpc_udf() 2240 2241 @dist_init 2242 def test_profiler_with_sync_rpc_udf_single_threaded(self): 2243 self._run_test_profiler_with_sync_rpc_udf() 2244 2245 def _run_test_profiler_with_sync_rpc_builtin(self): 2246 self._profiler_test_with_rpc( 2247 RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1)) 2248 ) 2249 self._profiler_test_with_rpc( 2250 RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1)), 2251 use_record_function=True 2252 ) 2253 2254 @dist_init 2255 def test_profiler_with_sync_rpc_builtin(self): 2256 self._run_test_profiler_with_sync_rpc_builtin() 2257 2258 @dist_init 2259 def test_profiler_with_sync_rpc_builtin_single_threaded(self): 2260 self._run_test_profiler_with_sync_rpc_builtin() 2261 2262 def _run_test_profiler_with_async_rpc_udf(self): 2263 self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,)) 2264 self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,), 2265 use_record_function=True) 2266 # Test to ensure that kineto profiler enabled in RPC does not enable 2267 # RPC profiling (it is unsupported) and does not result in issues. 2268 self._profiler_test_with_rpc( 2269 RPCExecMode.ASYNC, my_sleep_func, args=(1,), kineto_profile=True 2270 ) 2271 2272 @dist_init 2273 def test_profiler_with_async_rpc_udf(self): 2274 self._run_test_profiler_with_async_rpc_udf() 2275 2276 @dist_init 2277 def test_profiler_with_async_rpc_udf_single_threaded(self): 2278 self._run_test_profiler_with_async_rpc_udf() 2279 2280 def _run_test_profiler_with_async_rpc_builtin(self): 2281 self._profiler_test_with_rpc( 2282 RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1)) 2283 ) 2284 self._profiler_test_with_rpc( 2285 RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1)), 2286 use_record_function=True 2287 ) 2288 2289 @dist_init 2290 def test_profiler_with_async_rpc_builtin(self): 2291 self._run_test_profiler_with_async_rpc_builtin() 2292 2293 @dist_init 2294 def test_profiler_with_async_rpc_builtin_single_threaded(self): 2295 self._run_test_profiler_with_async_rpc_builtin() 2296 2297 def _run_test_profiler_with_remote_udf(self): 2298 self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,)) 2299 self._profiler_test_with_rpc( 2300 RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True 2301 ) 2302 # test remote to self 2303 self._profiler_test_with_rpc( 2304 RPCExecMode.REMOTE, my_sleep_func, args=(1,), dst=self.rank 2305 ) 2306 2307 @dist_init 2308 def test_profiler_with_remote_udf(self): 2309 self._run_test_profiler_with_remote_udf() 2310 2311 @dist_init 2312 def test_profiler_with_remote_udf_single_threaded(self): 2313 self._run_test_profiler_with_remote_udf() 2314 2315 def _run_test_profiler_with_remote_builtin(self): 2316 self._profiler_test_with_rpc( 2317 RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1)) 2318 ) 2319 self._profiler_test_with_rpc( 2320 RPCExecMode.REMOTE, torch.mul, 
args=(torch.ones(1), torch.ones(1)), 2321 use_record_function=True 2322 ) 2323 # test remote to self 2324 self._profiler_test_with_rpc( 2325 RPCExecMode.REMOTE, 2326 torch.mul, 2327 args=(torch.ones(1), torch.ones(1)), 2328 dst=self.rank, 2329 ) 2330 2331 @dist_init 2332 def test_profiler_with_remote_builtin(self): 2333 self._run_test_profiler_with_remote_builtin() 2334 2335 @dist_init 2336 def test_profiler_with_remote_builtin_single_threaded(self): 2337 self._run_test_profiler_with_remote_builtin() 2338 2339 def _run_test_profiler_with_script_async_rpc(self): 2340 self._profiler_test_with_rpc( 2341 RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),) 2342 ) 2343 self._profiler_test_with_rpc( 2344 RPCExecMode.ASYNC, 2345 my_script_func, 2346 args=(torch.tensor(1),), 2347 use_record_function=True, 2348 ) 2349 2350 @dist_init 2351 def test_profiler_with_script_async_rpc(self): 2352 self._run_test_profiler_with_script_async_rpc() 2353 2354 @dist_init 2355 def test_profiler_with_script_async_rpc_single_threaded(self): 2356 self._run_test_profiler_with_script_async_rpc() 2357 2358 def _run_test_profiler_with_script_sync_rpc(self): 2359 self._profiler_test_with_rpc( 2360 RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),) 2361 ) 2362 self._profiler_test_with_rpc( 2363 RPCExecMode.SYNC, 2364 my_script_func, 2365 args=(torch.tensor(1),), 2366 use_record_function=True, 2367 ) 2368 2369 @dist_init 2370 def test_profiler_with_script_sync_rpc(self): 2371 self._run_test_profiler_with_script_sync_rpc() 2372 2373 @dist_init 2374 def test_profiler_with_script_sync_rpc_single_threaded(self): 2375 self._run_test_profiler_with_script_sync_rpc() 2376 2377 def _run_test_profiler_with_script_remote_rpc(self): 2378 self._profiler_test_with_rpc( 2379 RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),) 2380 ) 2381 self._profiler_test_with_rpc( 2382 RPCExecMode.REMOTE, 2383 my_script_func, 2384 args=(torch.tensor(1),), 2385 use_record_function=True, 2386 ) 2387 # test remote to self 2388 self._profiler_test_with_rpc( 2389 RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank 2390 ) 2391 2392 @dist_init 2393 def test_profiler_with_script_remote_rpc(self): 2394 self._run_test_profiler_with_script_remote_rpc() 2395 2396 @dist_init 2397 def test_profiler_with_script_remote_rpc_single_threaded(self): 2398 self._run_test_profiler_with_script_remote_rpc() 2399 2400 def _assert_top_level_events(self, process_global_events, expected_top_level_event_names): 2401 top_level_event_names = [] 2402 for thread_local_events in process_global_events: 2403 # Get top-level events from all events happened on a thread. 
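            # Sketch of the selection rule applied below (illustration only,
            # assuming each thread's events are reported in start-time order):
            # an event counts as top level when it starts after the previously
            # accepted top-level event has ended, so nested children are skipped.
            #
            #   |---- aten::add ----|        |---- aten::sub ----|
            #      |- aten::fill_ -|   <- child, not counted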
            last_end_time = 0
            for event in thread_local_events:
                event_name = event.name
                time_range = event.time_range
                if time_range.start > last_end_time:
                    top_level_event_names.append(event_name)
                    last_end_time = time_range.end
        top_level_event_names = sorted(top_level_event_names)
        expected_top_level_event_names = sorted(expected_top_level_event_names)
        self.assertEqual(
            top_level_event_names,
            expected_top_level_event_names,
            f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}",
        )

    @dist_init
    def test_server_process_global_profiler(self):
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker_name = worker_name(dst_rank)

        x = torch.tensor(1)
        y = torch.tensor(2)

        outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        outer_profile_rref.rpc_sync().__enter__()
        rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
        inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        inner_profile_rref.rpc_sync().__enter__()
        rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
        inner_profile_rref.rpc_sync().__exit__(None, None, None)
        outer_profile_rref.rpc_sync().__exit__(None, None, None)

        inner_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (inner_profile_rref,))
        expected_inner_events = ['aten::sub']
        expected_outer_events = expected_inner_events + ['aten::add']

        self._assert_top_level_events(inner_events, expected_inner_events)
        outer_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (outer_profile_rref,))
        self._assert_top_level_events(outer_events, expected_outer_events)

        inner_profile_rref.rpc_sync().key_averages()
        outer_profile_rref.rpc_sync().key_averages()

    @dist_init
    def test_async_record_function_double_end_callbacks(self):
        num_sleep_seconds = 1
        if self.rank == 1:
            # Validate that calling the function twice results in an error.
            with _profile() as pf:
                with torch.autograd.profiler.record_function("foo") as rf:
                    fut = rpc.rpc_async(
                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
                    )
                    rf._call_end_callbacks_on_future(fut)
                    with self.assertRaisesRegex(
                        RuntimeError, "can only be called once."
2463 ): 2464 rf._call_end_callbacks_on_future(fut) 2465 fut.wait() 2466 2467 @dist_init 2468 def test_async_record_function_legacy(self): 2469 # Test the legacy _record_function ops work 2470 # Note: These exist for backward compatibility with TorchScript 2471 num_sleep_seconds = 1 2472 if self.rank == 1: 2473 with _profile() as pf: 2474 try: 2475 handle = torch.ops.profiler._record_function_enter("foo", None) 2476 fut = rpc.rpc_async( 2477 worker_name(0), my_sleep_func, args=(num_sleep_seconds,) 2478 ) 2479 torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut) 2480 finally: 2481 torch.ops.profiler._record_function_exit(handle) 2482 2483 fut.wait() 2484 2485 @dist_init 2486 def test_async_record_function_cbs_jit_call(self): 2487 if self.rank == 1: 2488 with _profile() as pf: 2489 key = _build_rpc_profiling_key( 2490 RPCExecMode.ASYNC, 2491 torch._jit_internal._qualified_name(my_script_func), 2492 "worker1", 2493 "worker0", 2494 ) 2495 with torch.autograd.profiler.record_function(key) as rf: 2496 fut = rpc.rpc_async( 2497 worker_name(0), my_script_func, args=(torch.tensor(1),) 2498 ) 2499 # Intentionally calling record_function internals 2500 fut = torch.ops.profiler._call_end_callbacks_on_jit_fut(rf.record, fut) 2501 result = fut.wait() 2502 # Validate that the profiling future returns the same value as the RPC 2503 # future. 2504 expected = torch.add(torch.tensor(1), torch.tensor(1)) 2505 self.assertEqual(result, expected) 2506 events = pf.function_events 2507 rpc_event = get_function_event( 2508 events, torch._jit_internal._qualified_name(my_script_func) 2509 ) 2510 self.assertTrue(torch._jit_internal._qualified_name(my_script_func) in rpc_event.name) 2511 2512 @dist_init 2513 def test_py_class_constructor(self): 2514 n = self.rank + 1 2515 dst_rank = n % self.world_size 2516 ret = rpc.rpc_sync(worker_name(dst_rank), MyClass, args=(n,)) 2517 self.assertEqual(ret.a, n) 2518 2519 @dist_init 2520 def test_py_class_instance_method(self): 2521 n = self.rank + 1 2522 dst_rank = n % self.world_size 2523 ret = rpc.rpc_sync( 2524 worker_name(dst_rank), MyClass(2).my_instance_method, args=(n,) 2525 ) 2526 self.assertEqual(ret, MyClass(2).my_instance_method(n)) 2527 2528 @dist_init 2529 def test_py_class_method(self): 2530 n = self.rank + 1 2531 dst_rank = n % self.world_size 2532 ret = rpc.rpc_sync( 2533 worker_name(dst_rank), MyClass.my_class_method, args=(n, n + 1) 2534 ) 2535 self.assertEqual(ret, MyClass.my_class_method(n, n + 1)) 2536 2537 @dist_init 2538 def test_py_class_static_method(self): 2539 n = self.rank + 1 2540 dst_rank = n % self.world_size 2541 ret = rpc.rpc_sync( 2542 worker_name(dst_rank), MyClass.my_static_method, args=(n + 10,) 2543 ) 2544 self.assertEqual(ret, MyClass.my_static_method(n + 10)) 2545 2546 @dist_init 2547 def test_py_multi_async_call(self): 2548 n = self.rank + 1 2549 dst_rank = n % self.world_size 2550 dst_worker_info = rpc.get_worker_info(worker_name(dst_rank)) 2551 fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,)) 2552 fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2)) 2553 self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10)) 2554 self.assertEqual(fut2.wait(), min(n, n + 1, n + 2)) 2555 2556 @dist_init 2557 def test_py_no_return_result(self): 2558 n = self.rank + 1 2559 dst_rank = n % self.world_size 2560 ret = rpc.rpc_sync(worker_name(dst_rank), no_result) 2561 self.assertEqual(ret, no_result()) 2562 2563 @dist_init 2564 def test_py_tensors(self): 2565 n = self.rank + 1 2566 dst_rank = n % 
self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            my_tensor_function,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))

    @dist_init
    def test_py_tensors_multi_async_call(self):
        futs = []
        n = self.rank + 1
        dst_rank = n % self.world_size
        for i in range(100):
            fut = rpc.rpc_async(
                worker_name(dst_rank),
                my_tensor_function,
                args=(torch.ones(i, i), torch.ones(i, i)),
            )
            futs.append(fut)

        j = 0
        for val in torch.futures.wait_all(futs):
            self.assertEqual(
                val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
            )
            j += 1

    @dist_init
    def test_py_tensors_in_container(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        a = [torch.ones(n, n), torch.ones(n, n)]
        b = TensorClass(build_complex_tensors())
        c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
        ret = rpc.rpc_sync(
            worker_name(dst_rank), my_complex_tensor_function, args=(a, b, c)
        )
        self.assertEqual(ret, my_complex_tensor_function(a, b, c))

    @dist_init
    def test_py_nested_pickle(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            run_nested_pickle,
            args=(MyPickleClass(), torch.ones(2, 2)),
        )

        m = MyPickleClass()
        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))

    @dist_init
    def test_py_function_exception(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        with self.assertRaises(TypeError):
            ret = rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,))

    @dist_init
    def test_py_raise_in_user_func(self):
        with captured_output() as (_, err):
            # This barrier prevents a race condition where the main thread has
            # not entered the context manager when the remote function runs.
            initialize_pg(self.file_init_method, self.rank, self.world_size)
            dist.barrier()
            n = self.rank + 1
            dst_rank = n % self.world_size
            fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
            with self.assertRaisesRegex(ValueError, expected_err):
                fut.wait()
            # This barrier prevents a race condition where the main thread exits
            # the context manager before the remote function has run.
            dist.barrier()

        # Validate that trainers log errors when running functions.
        stderr_lines = err.getvalue()
        self.assertTrue(expected_err in stderr_lines)

    @dist_init
    def test_py_raise_in_user_func_escaped_str(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape)
        try:
            fut.wait()
        except ValueError as e:
            msg = str(e)
            # Ensure newlines are unescaped to provide a better repr of the error.
2658 self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape")) 2659 else: 2660 self.assertTrue(False, "expected raise_func_escape to raise ValueError.") 2661 2662 @dist_init 2663 def test_nested_rpc(self): 2664 self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1) 2665 2666 @dist_init 2667 def test_stress_light_rpc(self): 2668 self._stress_test_rpc(light_rpc) 2669 2670 @dist_init 2671 def test_stress_heavy_rpc(self): 2672 self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),)) 2673 2674 @dist_init 2675 def test_stress_heavy_rpc_torchscript(self): 2676 self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),)) 2677 2678 @dist_init 2679 def test_builtin_remote_ret(self): 2680 self._builtin_remote_ret( 2681 torch.ones(2, 2), 2682 torch.ones(2, 2), 2683 torch.ones(2, 2) * 2 2684 ) 2685 2686 @dist_init 2687 def test_builtin_remote_self(self): 2688 self._builtin_remote_self( 2689 torch.ones(2, 2), 2690 torch.ones(2, 2), 2691 torch.ones(2, 2) * 2 2692 ) 2693 2694 @staticmethod 2695 def _multi_args_fn(n, sparse=False): 2696 if sparse: 2697 return (build_sparse_tensor(), build_sparse_tensor()) 2698 else: 2699 return (torch.ones(n, n), torch.ones(n, n)) 2700 2701 @dist_init 2702 def test_multi_builtin_remote_ret(self): 2703 self._test_multi_remote_call( 2704 torch.add, False, 2705 args_fn=RpcTest._multi_args_fn 2706 ) 2707 2708 @dist_init 2709 def test_py_udf_remote(self): 2710 n = self.rank + 1 2711 dst_rank = n % self.world_size 2712 rref = rpc.remote( 2713 worker_name(dst_rank), 2714 my_function, 2715 kwargs={"a": n, "b": n + 1, "c": n + 2}, 2716 ) 2717 self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2)) 2718 2719 @staticmethod 2720 def _multi_kwargs_fn(n, sparse=False): 2721 if sparse: 2722 return { 2723 "a": build_sparse_tensor(), 2724 "b": build_sparse_tensor(), 2725 "c": build_sparse_tensor() 2726 } 2727 else: 2728 return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)} 2729 2730 @dist_init 2731 def test_multi_py_udf_remote(self): 2732 self._test_multi_remote_call( 2733 my_function, 2734 False, 2735 kwargs_fn=RpcTest._multi_kwargs_fn 2736 ) 2737 2738 @dist_init 2739 def test_py_rref_args(self): 2740 self._py_rref_args( 2741 torch.ones(2, 2), 2742 1, 2743 torch.ones(2, 2), 2744 2, 2745 torch.ones(2, 2) * 2 + 3) 2746 2747 @dist_init 2748 def test_py_rref_args_user_share(self): 2749 self._py_rref_args_user_share( 2750 torch.ones(2, 2), 2751 1, 2752 2, 2753 torch.ones(2, 2), 2754 3, 2755 4, 2756 torch.ones(2, 2) * 2 + 10 2757 ) 2758 2759 @dist_init 2760 def test_py_rpc_rref_args(self): 2761 self._py_rpc_rref_args( 2762 torch.ones(2, 2), 2763 1, 2764 2, 2765 torch.ones(2, 2), 2766 3, 2767 4, 2768 torch.ones(2, 2) * 2 + 10 2769 ) 2770 2771 @dist_init 2772 def test_nested_remote(self): 2773 self._nested_remote( 2774 nested_remote, 2775 torch.ones(2, 2) + 3 2776 ) 2777 2778 @dist_init 2779 def test_nested_rref(self): 2780 self._nested_rref( 2781 nested_rref, 2782 torch.ones(2, 2) + 1, 2783 torch.ones(2, 2) + 2 2784 ) 2785 2786 @dist_init 2787 def test_nested_rref_stress(self): 2788 self._nested_rref_stress( 2789 nested_rref, 2790 torch.ones(2, 2) + 1, 2791 torch.ones(2, 2) + 2 2792 ) 2793 2794 @dist_init 2795 def test_multi_layer_nested_async_rpc(self): 2796 # This test will exit right away, but there will be a chain of async 2797 # RPCs. The termination algorithm should detect those messages properly. 2798 # Otherwise, some peer could exit early, leaving others to timeout 2799 # errors or connection closed errors. 
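        # Rough sketch of the shape of that chain (illustration only; the real
        # helper is multi_layer_nested_async_rpc defined elsewhere in this file
        # and may differ in detail):
        #
        #   def multi_layer_nested_async_rpc(dst, world_size, ttl):
        #       if ttl > 0:
        #           rpc.rpc_async(
        #               worker_name(dst),
        #               multi_layer_nested_async_rpc,
        #               args=((dst + 1) % world_size, world_size, ttl - 1),
        #           )
        #
        # Each hop fires another async RPC and returns without waiting on the
        # future, which is the pattern the termination detection has to handle.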
2800 ttl = 20 2801 n = self.rank + 1 2802 dst_rank = n % self.world_size 2803 2804 multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl) 2805 2806 @dist_init 2807 def test_remote_with_exception(self): 2808 n = self.rank + 1 2809 dst_rank = n % self.world_size 2810 # check ref to other workers 2811 rref = rpc.remote(worker_name(dst_rank), raise_func) 2812 with self.assertRaises(ValueError): 2813 rref.to_here() 2814 # check ref to itself 2815 rref = rpc.remote(worker_name(self.rank), no_result, args=(10,)) 2816 with self.assertRaises(TypeError): 2817 rref.to_here() 2818 2819 @dist_init 2820 def test_rpc_return_rref(self): 2821 n = self.rank + 1 2822 dst_rank1 = n % self.world_size 2823 dst_rank2 = (n + 1) % self.world_size 2824 rref = rpc.rpc_sync( 2825 worker_name(dst_rank1), 2826 rpc_return_rref, 2827 args=(worker_name(dst_rank2),), 2828 ) 2829 self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1) 2830 2831 @dist_init 2832 def test_rref_forward_chain(self): 2833 ttl = 8 2834 n = self.rank + 1 2835 dst_rank = n % self.world_size 2836 2837 rref = rpc.remote( 2838 worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1) 2839 ) 2840 2841 ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl) 2842 2843 for i in range(ttl): 2844 self.assertEqual(len(ret_rref), 1) 2845 ret_rref = ret_rref[0].to_here() 2846 2847 ret = ret_rref 2848 self.assertEqual(ret, torch.add(torch.ones(n, n), 1)) 2849 2850 @dist_init 2851 def test_local_rref_no_fork(self): 2852 local_rref = RRef(35) 2853 self.assertEqual(local_rref.local_value(), 35) 2854 2855 @dist_init 2856 def test_local_value_not_on_owner(self): 2857 # ensure that an error message is thrown if a user tries to call 2858 # local_value() on a non-owning node. 2859 next_rank = (self.rank + 1) % self.world_size 2860 rref = rpc.remote( 2861 worker_name(next_rank), torch.add, args=(torch.ones(1), torch.ones(1)) 2862 ) 2863 with self.assertRaisesRegex( 2864 RuntimeError, ( 2865 fr"For UserRRef\(rref_id=GloballyUniqueId\(created_on={self.rank}, local_id=0\), " 2866 fr"fork_id=GloballyUniqueId\(created_on={self.rank}, local_id=1\)\), " 2867 r"can't call localValue\(\) on user " 2868 fr"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). 
" 2869 fr"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)" 2870 ) 2871 ): 2872 rref.local_value() 2873 2874 @dist_init 2875 def test_return_local_rrefs(self): 2876 n = self.rank + 1 2877 dst_rank = n % self.world_size 2878 2879 rref_list = rpc.rpc_sync( 2880 worker_name(dst_rank), get_rref_list, args=([1, 2, 3],) 2881 ) 2882 2883 for rref in rref_list: 2884 rpc.rpc_sync( 2885 rref.owner(), 2886 _call_method_on_rref, 2887 args=(MyClass.increment_value, rref, 10), 2888 ) 2889 2890 rets = [ 2891 rpc.rpc_sync( 2892 rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref) 2893 ) 2894 for rref in rref_list 2895 ] 2896 2897 self.assertEqual(rets, [11, 12, 13]) 2898 2899 @dist_init 2900 def _test_rref_type(self, blocking): 2901 2902 def launched_rpc(events): 2903 expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner" 2904 return any(e.name.startswith(expected_name) for e in events) 2905 2906 dst = worker_name((self.rank + 1) % self.world_size) 2907 rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1)) 2908 2909 with _profile() as p: 2910 t = rref._get_type(blocking=blocking) 2911 if not blocking: 2912 t = t.wait() 2913 2914 self.assertTrue(launched_rpc(p.function_events)) 2915 expected_type = type(torch.ones(2)) 2916 self.assertEqual(t, expected_type) 2917 2918 futs = [] 2919 2920 def verify(fut): 2921 self.assertEqual(fut.value(), expected_type) 2922 2923 with _profile() as p: 2924 for _ in range(10): 2925 t = rref._get_type(blocking=blocking) 2926 if not blocking: 2927 futs.append(t) 2928 t.add_done_callback(verify) 2929 t = t.wait() 2930 self.assertEqual(t, expected_type) 2931 2932 if not blocking: 2933 # Note that cached calls with blocking=False all return the same 2934 # cached original future. 2935 first_fut = futs[0] 2936 for f in futs[1:]: 2937 self.assertTrue(f is first_fut) 2938 # Ensure we never launch another RPC, other than for the very 2939 # first call. 
2940 self.assertFalse(launched_rpc(p.function_events)) 2941 self.assertEqual(t, type(torch.ones(2))) 2942 2943 rref = rpc.remote(dst, MyClass, args=(0,)) 2944 rref_type = rref._get_type(blocking=blocking) 2945 if not blocking: 2946 rref_type = rref_type.wait() 2947 self.assertEqual(rref_type, MyClass) 2948 2949 def test_rref_type_blocking(self): 2950 self._test_rref_type(blocking=True) 2951 2952 def test_rref_type_non_blocking(self): 2953 self._test_rref_type(blocking=False) 2954 2955 @dist_init 2956 def _test_rref_type_with_error(self, blocking): 2957 dst = worker_name((self.rank + 1) % self.world_size) 2958 # 10 ms timeout 2959 rref = rpc.remote(dst, raise_func) 2960 # Blocking: error raised inline 2961 if blocking: 2962 with self.assertRaisesRegex(ValueError, "Expected error"): 2963 rref._get_type(blocking=blocking) 2964 else: 2965 # Non-blocking: Immediately return future, block on wait 2966 fut = rref._get_type(blocking=blocking) 2967 with self.assertRaisesRegex(ValueError, "Expected error"): 2968 fut.wait() 2969 2970 2971 def test_rref_type_with_error_blocking(self): 2972 self._test_rref_type_with_error(blocking=True) 2973 2974 def test_rref_type_with_error_non_blocking(self): 2975 self._test_rref_type_with_error(blocking=False) 2976 2977 @dist_init 2978 def _test_rref_type_owner(self, blocking): 2979 rref = RRef(torch.ones(2) + 1) 2980 rref_type = rref._get_type(blocking=blocking) 2981 if not blocking: 2982 rref_type = rref_type.wait() 2983 self.assertEqual(rref_type, type(torch.ones(2))) 2984 2985 rref = RRef(MyClass(0)) 2986 rref_type = rref._get_type(blocking=blocking) 2987 if not blocking: 2988 rref_type = rref_type.wait() 2989 self.assertEqual(rref_type, MyClass) 2990 2991 def test_rref_type_owner_blocking(self): 2992 self._test_rref_type_owner(blocking=True) 2993 2994 def test_rref_type_owner_non_blocking(self): 2995 self._test_rref_type_owner(blocking=False) 2996 2997 @staticmethod 2998 def _slow_add(x, y): 2999 time.sleep(1) 3000 return x + y 3001 3002 @dist_init 3003 def test_rref_type_slow_init(self): 3004 dst = worker_name((self.rank + 1) % self.world_size) 3005 rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1)) 3006 self.assertEqual(rref._get_type(), type(torch.ones(2))) 3007 3008 @dist_init 3009 def test_owner_equality(self): 3010 a = RRef(40) 3011 b = RRef(50) 3012 3013 other_rank = (self.rank + 1) % self.world_size 3014 other_a = rpc.remote( 3015 worker_name(other_rank), torch.add, args=(torch.ones(1), 1) 3016 ) 3017 other_b = rpc.remote( 3018 worker_name(other_rank), torch.add, args=(torch.ones(1), 1) 3019 ) 3020 other_a.to_here() # to ensure clean termination 3021 other_b.to_here() 3022 3023 self.assertNotEqual(a.owner(), 23) 3024 self.assertEqual(other_a.owner(), other_b.owner()) 3025 self.assertNotEqual(a.owner(), other_a.owner()) 3026 self.assertEqual(other_a.owner(), other_a.owner()) 3027 self.assertEqual(other_a.owner(), other_b.owner()) 3028 self.assertEqual(a.owner(), a.owner()) 3029 self.assertEqual(a.owner(), b.owner()) 3030 self.assertEqual(a.owner(), rpc.get_worker_info()) 3031 x = {} 3032 x[a.owner()] = a 3033 x[other_a.owner()] = other_a 3034 self.assertEqual(x[a.owner()], a) 3035 self.assertEqual(x[b.owner()], a) 3036 self.assertEqual(x[other_a.owner()], other_a) 3037 self.assertEqual(x[other_b.owner()], other_a) 3038 self.assertEqual(len(x), 2) 3039 3040 @dist_init 3041 def test_pass_local_rrefs(self): 3042 n = self.rank + 1 3043 dst_rank = n % self.world_size 3044 dst_worker = worker_name(dst_rank) 3045 3046 rref = RRef(40) 3047 
self.assertEqual( 3048 rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90 3049 ) 3050 self.assertEqual( 3051 rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90 3052 ) 3053 self.assertEqual( 3054 rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90 3055 ) 3056 3057 @dist_init 3058 def test_remote_same_worker(self): 3059 n = self.rank + 1 3060 dst_rank = n % self.world_size 3061 rref_a = rpc.remote( 3062 worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2) 3063 ) 3064 rref_b = rpc.remote( 3065 worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1) 3066 ) 3067 rref_c = rpc.remote( 3068 worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) 3069 ) 3070 self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) 3071 3072 @dist_init(setup_rpc=True) 3073 def test_call_method_on_rref(self): 3074 """ 3075 Tests that it is possible to call an instance method on a remote object 3076 by using rref.owner() as destination of the call. 3077 """ 3078 vals = [10, 2, 5, 7] 3079 dst_rank = (self.rank + 1) % self.world_size 3080 dst_worker = worker_name(dst_rank) 3081 3082 # creates a remote object 3083 rref = rpc.remote(dst_worker, MyClass, args=(vals[0],)) 3084 3085 # modifies state of the remote object 3086 rpc.rpc_sync( 3087 rref.owner(), 3088 _call_method_on_rref, 3089 args=(MyClass.increment_value, rref, vals[1]), 3090 ) 3091 rpc.rpc_async( 3092 rref.owner(), 3093 _call_method_on_rref, 3094 args=(MyClass.increment_value, rref, vals[2]), 3095 ).wait() 3096 rpc.remote( 3097 rref.owner(), 3098 _call_method_on_rref, 3099 args=(MyClass.increment_value, rref, vals[3]), 3100 ).to_here() 3101 3102 # queries state of the remote object 3103 result = rpc.rpc_sync( 3104 dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref) 3105 ) 3106 3107 self.assertEqual(result, sum(vals)) 3108 3109 # Notice `rpc.api.shutdown()` accesses 3110 # `_delete_all_user_and_unforked_owner_rrefs` through 3111 # `torch.distributed.rpc.api`, so patching 3112 # `torch.distributed.rpc._delete_all_user_and_unforked_owner_rrefs` will 3113 # not help. 3114 @mock.patch.object(torch.distributed.rpc.api, "_delete_all_user_and_unforked_owner_rrefs") 3115 def _test_rref_leak(self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak): 3116 rpc.init_rpc( 3117 name=worker_name(self.rank), 3118 backend=self.rpc_backend, 3119 rank=self.rank, 3120 world_size=self.world_size, 3121 rpc_backend_options=self.rpc_backend_options, 3122 ) 3123 3124 initialize_pg(self.file_init_method, self.rank, self.world_size) 3125 # Wait for all init to complete. 
3126 dist.barrier() 3127 3128 rref = rpc.remote( 3129 worker_name((self.rank + 1) % self.world_size), 3130 torch.add, 3131 args=(torch.ones(2, 2), 1), 3132 ) 3133 3134 import torch.distributed.rpc.api as api 3135 3136 if ignore_leak: 3137 api._ignore_rref_leak = True 3138 rpc.shutdown(graceful=True) 3139 else: 3140 api._ignore_rref_leak = False 3141 with self.assertRaisesRegex(RuntimeError, "Leaking RRef"): 3142 rpc.shutdown(graceful=True) 3143 3144 @dist_init(setup_rpc=False) 3145 def test_rref_leak(self): 3146 self._test_rref_leak(ignore_leak=False) 3147 3148 @dist_init(setup_rpc=False) 3149 def test_ignore_rref_leak(self): 3150 self._test_rref_leak(ignore_leak=True) 3151 3152 @dist_init 3153 def test_rref_str(self): 3154 rref1 = RRef(self.rank) 3155 id_class = "GloballyUniqueId" 3156 self.assertEqual( 3157 f"OwnerRRef({id_class}(created_on={self.rank}, local_id=0))", rref1.__str__() 3158 ) 3159 3160 dst_rank = (self.rank + 1) % self.world_size 3161 rref2 = rpc.remote( 3162 worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1) 3163 ) 3164 self.assertEqual( 3165 rref2.__str__(), 3166 f"UserRRef(RRefId = {id_class}(created_on={self.rank}, local_id=1), " 3167 f"ForkId = {id_class}(created_on={self.rank}, local_id=2))", 3168 ) 3169 3170 @dist_init 3171 def test_rref_get_future(self): 3172 # Tests that we can obtain the future corresponding to the creation of 3173 # the RRef on remote end 3174 if self.rank == 0: 3175 # Builtin 3176 rref = rpc.remote(worker_name(1), torch.add, args=(1, 1)) 3177 rref.to_here() 3178 fut = rref._get_future() 3179 self.assertIsInstance(fut, torch._C.Future) 3180 3181 # UDF 3182 rref = rpc.remote(worker_name(1), foo_add, args=()) 3183 rref.to_here() 3184 fut = rref._get_future() 3185 self.assertIsInstance(fut, torch._C.Future) 3186 3187 # Script 3188 rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1), )) 3189 rref.to_here() 3190 fut = rref._get_future() 3191 self.assertIsInstance(fut, torch._C.Future) 3192 3193 3194 @dist_init 3195 def test_rref_context_debug_info(self): 3196 # This test checks local states that are modified by remote workers. 3197 # This means that we would need barrier before and after every check. 3198 # The barrier before the check makes sure that all previous states are 3199 # cleared globally, the barrier after ensures that no following states 3200 # change gets into the current check. 3201 initialize_pg(self.file_init_method, self.rank, self.world_size) 3202 3203 # Check 1: local RRef does not update owners_ map or add a pending user. 
3204 ################################################# 3205 3206 rref1 = RRef(self.rank) 3207 3208 # don't need a barrier here as local RRef is handled by this thread 3209 info = _rref_context_get_debug_info() 3210 self.assertIn("num_owner_rrefs", info) 3211 self.assertIn("num_pending_users", info) 3212 # RRef on local value is not added to context until shared across RPC 3213 self.assertEqual(0, int(info["num_owner_rrefs"])) 3214 self.assertEqual(0, int(info["num_pending_users"])) 3215 # barrier after the check 1 3216 dist.barrier() 3217 3218 # Check 2: Sharing RRef as an arg should update owners_ map 3219 ########################################################### 3220 3221 dst_rank = (self.rank + 1) % self.world_size 3222 rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,)) 3223 3224 # barrier before check 2 3225 wait_until_pending_futures_and_users_flushed() 3226 dist.barrier() 3227 3228 info = _rref_context_get_debug_info() 3229 self.assertIn("num_owner_rrefs", info) 3230 self.assertEqual(1, int(info["num_owner_rrefs"])) 3231 # no pending users since the fork is finished 3232 self.assertEqual(0, int(info["num_pending_users"])) 3233 # barrier after check 2 3234 dist.barrier() 3235 3236 # clear states for check 2 3237 rpc.rpc_sync(worker_name(dst_rank), clear_global_rref) 3238 3239 # Wait for owner rref to be cleared. 3240 while int(info["num_owner_rrefs"]) != 0: 3241 info = _rref_context_get_debug_info() 3242 time.sleep(0.1) 3243 dist.barrier() 3244 3245 # Check 3: rpc.remote call should update owners_ map 3246 #################################################### 3247 rref2 = rpc.remote( 3248 worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1) 3249 ) 3250 rref3 = rpc.remote( 3251 worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1) 3252 ) 3253 rref2.to_here() 3254 rref3.to_here() 3255 3256 # barrier before check 3 3257 wait_until_pending_futures_and_users_flushed() 3258 dist.barrier() 3259 3260 info = _rref_context_get_debug_info() 3261 self.assertIn("num_owner_rrefs", info) 3262 self.assertEqual(2, int(info["num_owner_rrefs"])) 3263 # no pending users since the fork is finished 3264 self.assertEqual(0, int(info["num_pending_users"])) 3265 3266 # barrier after check 3 3267 dist.barrier() 3268 3269 @dist_init 3270 def test_disable_gil_profiling(self): 3271 # test that rpc.enable_gil_profiling(false) will result in 3272 # GIL wait time not being recorded. 3273 3274 # GIL profiling should be disabled by default. 3275 dst_rank = (self.rank + 1) % self.world_size 3276 rpc.rpc_sync( 3277 worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) 3278 ) 3279 info = rpc.api._get_current_rpc_agent().get_debug_info() 3280 self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"]) 3281 rpc.enable_gil_profiling(True) 3282 rpc.rpc_sync( 3283 worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) 3284 ) 3285 info = rpc.api._get_current_rpc_agent().get_debug_info() 3286 self.assertIn("agent.gil_average_wait_time_us", info) 3287 3288 @dist_init(setup_rpc=False) 3289 def test_local_shutdown(self): 3290 # test that we can start RPC and then immediately locally shutdown 3291 # without sending any messages. 3292 rpc.init_rpc( 3293 name="worker%d" % self.rank, 3294 backend=self.rpc_backend, 3295 rank=self.rank, 3296 world_size=self.world_size, 3297 rpc_backend_options=self.rpc_backend_options, 3298 ) 3299 # pass in graceful=False to ensure that we don't wait for other workers. 
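        # For contrast, a brief summary of the two shutdown modes used in these
        # tests (hedged; see the rpc.shutdown documentation for the full contract):
        #
        #   rpc.shutdown(graceful=True)    # block until all peers reach shutdown
        #                                  # and pending UserRRefs are deleted
        #   rpc.shutdown(graceful=False)   # tear down the local agent immediately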
3300 rpc.shutdown(graceful=False) 3301 3302 @dist_init 3303 def test_debug_info(self): 3304 # only test keys in this test case. Values should be covered by 3305 # individual module debug info tests 3306 import torch.distributed.autograd as dist_autograd 3307 3308 info = _get_debug_info() 3309 rref_info = _rref_context_get_debug_info() 3310 agent_info = rpc.api._get_current_rpc_agent().get_debug_info() 3311 autograd_info = dist_autograd._get_debug_info() 3312 common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys() 3313 self.assertEqual(0, len(common_keys)) 3314 expected = {} 3315 expected.update(rref_info) 3316 expected.update(agent_info) 3317 expected.update(autograd_info) 3318 # NB: Key ordering is only preserved in python 3.6+. So here, we 3319 # manually check keys are equal. 3320 for key in expected.keys(): 3321 self.assertIn(key, info.keys()) 3322 3323 for key in info.keys(): 3324 self.assertIn(key, expected.keys()) 3325 3326 @dist_init(setup_rpc=False) 3327 @skip_but_pass_in_sandcastle_if( 3328 IS_MACOS, 3329 "Test is flaky on MacOS since libuv error handling is not as robust as TCP", 3330 ) 3331 def test_handle_send_exceptions(self): 3332 # test that if a callee node has gone down, we raise an appropriate 3333 # exception instead of just crashing. 3334 rpc.init_rpc( 3335 name="worker%d" % self.rank, 3336 backend=self.rpc_backend, 3337 rank=self.rank, 3338 world_size=self.world_size, 3339 rpc_backend_options=self.rpc_backend_options, 3340 ) 3341 rpc._set_rpc_timeout(10) 3342 # This barrier is needed to ensure that some workers do not exit before 3343 # others have been brought up. 3344 initialize_pg(self.file_init_method, self.rank, self.world_size) 3345 dist.barrier() 3346 if self.rank == 1: 3347 dst_rank = (self.rank + 1) % self.world_size 3348 dst_worker = worker_name(dst_rank) 3349 # allow destination worker to exit without joining 3350 error_str = self.get_shutdown_error_regex() 3351 wait_until_node_failure(dst_rank, error_str) 3352 fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3)) 3353 # Shutdown sequence is not very well defined and as a result 3354 # we can see any of the error messages defined in get_shutdown_error_regex. 3355 with self.assertRaisesRegex(RuntimeError, error_str): 3356 fut.wait() 3357 # exit all workers non-gracefully. 3358 rpc.shutdown(graceful=False) 3359 3360 @dist_init 3361 def test_deadlock(self): 3362 # this test is copied from https://github.com/pytorch/pytorch/issues/45089 3363 if self.rank == 1: 3364 dst1 = worker_name((self.rank + 1) % self.world_size) 3365 x = torch.ones(2) 3366 y = torch.ones(2) 3367 rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait() 3368 3369 dist_initialized = dist.is_initialized() 3370 if not dist_initialized: 3371 dist.init_process_group( 3372 backend="gloo", 3373 init_method=self.file_init_method, 3374 rank=self.rank, 3375 world_size=self.world_size, 3376 ) 3377 3378 @dist_init(setup_rpc=False) 3379 def test_local_shutdown_with_rpc(self): 3380 # test that we can start RPC, send RPCs, and then run local shutdown. 3381 rpc.init_rpc( 3382 name="worker%d" % self.rank, 3383 backend=self.rpc_backend, 3384 rank=self.rank, 3385 world_size=self.world_size, 3386 rpc_backend_options=self.rpc_backend_options, 3387 ) 3388 n = self.rank + 1 3389 dst_rank = n % self.world_size 3390 rpc.rpc_sync( 3391 worker_name(dst_rank), 3392 torch.add, 3393 args=(torch.ones(n, n), torch.ones(n, n)), 3394 ) 3395 # A barrier is needed to ensure that all RPCs are processed. 
3396 # Otherwise, some RPCs can timeout since the receiving end 3397 # has terminated. 3398 initialize_pg(self.file_init_method, self.rank, self.world_size) 3399 dist.barrier() 3400 # pass in graceful=False to ensure that we don't wait for other workers. 3401 rpc.shutdown(graceful=False) 3402 3403 @dist_init(setup_rpc=False) 3404 def test_set_and_get_default_rpc_timeout(self): 3405 timeout = 0.5 3406 3407 # A new `RpcBackendOptions` is constructed 3408 # when accessing `self.rpc_backend_options`. 3409 rpc_backend_options = self.rpc_backend_options 3410 rpc_backend_options.rpc_timeout = timeout 3411 3412 rpc.init_rpc( 3413 name=worker_name(self.rank), 3414 backend=self.rpc_backend, 3415 rank=self.rank, 3416 world_size=self.world_size, 3417 rpc_backend_options=rpc_backend_options, 3418 ) 3419 set_timeout = rpc.get_rpc_timeout() 3420 self.assertEqual(timeout, set_timeout) 3421 rpc.shutdown() 3422 3423 @dist_init 3424 def test_default_timeout_used(self): 3425 """ 3426 Tests that if no timeout is passed into rpc_async and rpc_sync, then the 3427 default timeout is used. 3428 """ 3429 dst_rank = (self.rank + 1) % self.world_size 3430 rpc._set_rpc_timeout(0.001) # 1 ms 3431 # futures should time out and be marked with an exception indicating it as such. 3432 futs = [ 3433 rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()) 3434 for _ in range(10) 3435 ] 3436 expected_error = self.get_timeout_error_regex() 3437 for fut in futs: 3438 with self.assertRaisesRegex(RuntimeError, expected_error): 3439 fut.wait() 3440 3441 # ensure that if a new timeout is set old futures don't time out but new ones do. 3442 rpc._set_rpc_timeout(200) # 200 seconds 3443 # create a longstanding RPC. 3444 fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,)) 3445 # now, set a short timeout. 3446 rpc._set_rpc_timeout(0.001) 3447 # fut2 should time out, fut1 should not. 3448 fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,)) 3449 with self.assertRaisesRegex(RuntimeError, expected_error): 3450 fut2.wait() 3451 fut1.wait() 3452 3453 # Zero timeout means infinity, so future should run to completion. 3454 rpc._set_rpc_timeout(0) 3455 rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait() 3456 3457 # reset to default timeout so shutdown messages can process cleanly. 3458 rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) 3459 3460 @dist_init 3461 def test_rpc_timeouts(self): 3462 # TODO: enable timeouts for rpc.remote/RRef (https://github.com/pytorch/pytorch/issues/33803) 3463 dst_rank = (self.rank + 1) % self.world_size 3464 dst_worker = worker_name(dst_rank) 3465 timeout = 0.1 # 100 ms 3466 expected_error = self.get_timeout_error_regex() 3467 # Test async UDF 3468 fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout) 3469 with self.assertRaisesRegex(RuntimeError, expected_error): 3470 fut.wait() 3471 3472 # Ensure run to completion if there is no timeout and we use the default 3473 # RPC timeout. 3474 rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait() 3475 3476 # Test sync UDF 3477 with self.assertRaisesRegex(RuntimeError, expected_error): 3478 rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout) 3479 3480 # Ensure run to completion if there is no timeout and we use the default 3481 # RPC timeout. 3482 rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,)) 3483 3484 # If we set a default timeout for RPCs, it should be respected, though 3485 # still overridden if we pass in a different timeout to the APIs. 
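        # Sketch of the timeout precedence exercised below (values taken from
        # this test; the per-call argument always wins over the agent default):
        #
        #   rpc._set_rpc_timeout(0.001)                                     # default: 1 ms
        #   rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))              # times out
        #   rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5)   # completes
        #   rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0)   # 0 = no timeout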
3486 rpc._set_rpc_timeout(0.001) 3487 fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)) 3488 with self.assertRaisesRegex(RuntimeError, expected_error): 3489 fut.wait() 3490 with self.assertRaisesRegex(RuntimeError, expected_error): 3491 rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,)) 3492 3493 # The RPCs should run to completion since we override the timeout. 3494 rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait() 3495 rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5) 3496 # Passing in a zero timeout should ensure that the RPC won't time out. 3497 rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait() 3498 rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0) 3499 # Reset for clean shutdown 3500 rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) 3501 3502 def test_dist_init_decorator(self): 3503 @dist_init(setup_rpc=False) 3504 def test_func(self): 3505 return "expected result" 3506 3507 self.assertEqual(test_func(self), "expected result") 3508 3509 @dist_init 3510 def test_func(self): 3511 return "expected result" 3512 3513 self.assertEqual(test_func(self), "expected result") 3514 3515 def test_use_rpc_pickler(self): 3516 class TestPickler: 3517 pass 3518 3519 test_pickler = TestPickler() 3520 with _use_rpc_pickler(test_pickler): 3521 self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler) 3522 self.assertTrue( 3523 torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler 3524 ) 3525 3526 @dist_init 3527 def test_wait_all(self): 3528 with _wait_all(): 3529 self.assertTrue(_thread_local_var.future_list == []) 3530 dst = worker_name((self.rank + 1) % self.world_size) 3531 fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) 3532 self.assertTrue(len(_thread_local_var.future_list) == 1) 3533 self.assertTrue(isinstance(_thread_local_var.future_list[0], torch._C.Future)) 3534 self.assertTrue(fut.done()) 3535 self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) 3536 self.assertFalse(hasattr(_thread_local_var, "future_list")) 3537 3538 @dist_init 3539 def test_wait_all_multiple_call(self): 3540 with _wait_all(): 3541 self.assertTrue(_thread_local_var.future_list == []) 3542 dst = worker_name((self.rank + 1) % self.world_size) 3543 for i in range(20): 3544 fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1)) 3545 res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1)) 3546 self.assertEqual(res, torch.ones(i, i) + 1) 3547 self.assertEqual(fut.wait(), torch.ones(i, i) + 1) 3548 self.assertTrue(len(_thread_local_var.future_list) == 20) 3549 self.assertFalse(hasattr(_thread_local_var, "future_list")) 3550 3551 @dist_init 3552 def test_wait_all_timeout(self): 3553 expected_error = self.get_timeout_error_regex() 3554 with self.assertRaisesRegex(RuntimeError, expected_error): 3555 with _wait_all(): 3556 self.assertTrue(_thread_local_var.future_list == []) 3557 dst = worker_name((self.rank + 1) % self.world_size) 3558 timeout = 0.1 # 100 ms 3559 fut = rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) 3560 self.assertFalse(hasattr(_thread_local_var, "future_list")) 3561 3562 @dist_init 3563 def test_wait_all_raise_in_user_func(self): 3564 with self.assertRaises(ValueError): 3565 with _wait_all(): 3566 self.assertTrue(_thread_local_var.future_list == []) 3567 dst = worker_name((self.rank + 1) % self.world_size) 3568 fut = rpc.rpc_async(dst, raise_func) 3569 self.assertFalse(hasattr(_thread_local_var, "future_list")) 3570 3571 @dist_init 3572 def test_wait_all_raise_in_body(self): 
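        # Hedged summary of the _wait_all() contract exercised by the tests above:
        # every Future created by rpc_async inside the block is appended to
        # _thread_local_var.future_list and all of them are waited on when the
        # block exits, so e.g.
        #
        #   with _wait_all():
        #       fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        #   # fut.done() is True here without an explicit fut.wait()
        #
        # and errors raised either by a future or by the body itself propagate
        # out of the with statement.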
3573 with self.assertRaises(ValueError): 3574 with _wait_all(): 3575 raise_func() 3576 self.assertFalse(hasattr(_thread_local_var, "future_list")) 3577 3578 @dist_init 3579 def test_custom_exception_throw_during_reconstruction(self): 3580 """ 3581 Test that we still throw info about the remote side exception even when 3582 we cannot recreate it on client side. 3583 """ 3584 initialize_pg(self.file_init_method, self.rank, self.world_size) 3585 if self.rank != 0: 3586 exc_caught = False 3587 dst = worker_name(0) 3588 try: 3589 rpc.rpc_sync(dst, custom_raise_func, args=()) 3590 except RuntimeError as e: 3591 exc_caught = True 3592 msg = str(e) 3593 print(f"Got msg {msg}") 3594 self.assertTrue("Original exception on remote side was" in msg) 3595 self.assertTrue("CustomException" in msg) 3596 except BaseException as e: 3597 raise RuntimeError( 3598 f"Failure - expected RuntimeError, got {e}" 3599 ) from e 3600 finally: 3601 self.assertTrue(exc_caught) 3602 3603 dist.barrier() 3604 3605 3606 timed_out_rpc_event = None 3607 3608 @staticmethod 3609 def timed_out_rpc(): 3610 RpcTest.timed_out_rpc_event.wait() 3611 3612 @dist_init 3613 def test_wait_all_exit_early_python(self): 3614 # Initialize the event in the subprocess. 3615 RpcTest.timed_out_rpc_event = Event() 3616 3617 # Wait for all processes to initialize event. 3618 initialize_pg(self.file_init_method, self.rank, self.world_size) 3619 dist.barrier() 3620 3621 dst = worker_name((self.rank + 1) % self.world_size) 3622 fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc) 3623 fut2 = rpc.rpc_async(dst, raise_func) 3624 fut3 = rpc.rpc_async(dst, raise_func) 3625 3626 # We should receive the error from fut2 3627 with self.assertRaisesRegex(ValueError, expected_err): 3628 torch.futures.wait_all([fut1, fut2, fut3]) 3629 3630 # Unblock RPC thread for fut1 3631 RpcTest.timed_out_rpc_event.set() 3632 3633 @dist_init 3634 def test_wait_all_exit_early_builtin(self): 3635 # Initialize the event in the subprocess. 3636 RpcTest.timed_out_rpc_event = Event() 3637 3638 # Wait for all processes to initialize event. 3639 initialize_pg(self.file_init_method, self.rank, self.world_size) 3640 dist.barrier() 3641 3642 dst = worker_name((self.rank + 1) % self.world_size) 3643 fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc) 3644 fut2 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5))) 3645 fut3 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5))) 3646 3647 # We should receive the error from fut2 3648 with self.assertRaisesRegex(RuntimeError, "size of tensor"): 3649 torch.futures.wait_all([fut1, fut2, fut3]) 3650 3651 # Unblock RPC thread for fut1 3652 RpcTest.timed_out_rpc_event.set() 3653 3654 @dist_init 3655 def test_wait_all_exit_early_script_function(self): 3656 # Initialize the event in the subprocess. 3657 RpcTest.timed_out_rpc_event = Event() 3658 3659 # Wait for all processes to initialize event. 
3660 initialize_pg(self.file_init_method, self.rank, self.world_size) 3661 dist.barrier() 3662 3663 dst = worker_name((self.rank + 1) % self.world_size) 3664 fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc) 3665 fut2 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,)) 3666 fut3 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,)) 3667 3668 # We should receive the error from fut2 3669 with self.assertRaisesRegex(RuntimeError, expected_err): 3670 torch.futures.wait_all([fut1, fut2, fut3]) 3671 3672 # Unblock RPC thread for fut1 3673 RpcTest.timed_out_rpc_event.set() 3674 3675 3676 @dist_init 3677 def test_function_not_on_callee(self): 3678 # test that if a function does not exist on a callee, we don't crash, 3679 # instead we get an AttributeError indicating that the func does not exist. 3680 this_module = sys.modules[__name__] 3681 caller_worker = "worker0" 3682 callee_worker = "worker1" 3683 3684 if self.rank == 1: 3685 # Use delattr to remove the binding of a func on this nodes 3686 delattr(this_module, "foo_add") 3687 # notify remote end that we have removed it. 3688 rpc.rpc_sync(caller_worker, set_value, args=(self.rank,)) 3689 3690 if self.rank == 0: 3691 # func exists on caller, but not callee. 3692 # wait for remote end to remove the binding of foo_add func. 3693 wait_for_value_future() 3694 # Ensure that we have the attribute on this module. Otherwise, the test could fail due to a caller-side pickling error. 3695 self.assertTrue(hasattr(this_module, "foo_add")) 3696 with self.assertRaisesRegex( 3697 RuntimeError, "RPC pickler does not serialize" 3698 ): 3699 rpc.rpc_sync(callee_worker, foo_add, args=()) 3700 3701 @dist_init 3702 def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self): 3703 dst_worker_name = worker_name((self.rank + 1) % self.world_size) 3704 3705 a = MyClass(1) 3706 b = MyClass(2) 3707 3708 # This is to make Python not garbage collect a and b. 3709 a.other = b 3710 b.other = a 3711 3712 n = self.rank 3713 a.rref = rpc.remote( 3714 dst_worker_name, 3715 torch.add, 3716 args=(torch.ones(n, n), 2) 3717 ) 3718 3719 @dist_init(setup_rpc=False) 3720 def test_use_rref_after_shutdown(self): 3721 rpc.init_rpc( 3722 name="worker%d" % self.rank, 3723 backend=self.rpc_backend, 3724 rank=self.rank, 3725 world_size=self.world_size, 3726 rpc_backend_options=self.rpc_backend_options, 3727 ) 3728 n = self.rank + 1 3729 dst_rank = n % self.world_size 3730 rref = rpc.remote( 3731 worker_name(dst_rank), 3732 torch.add, 3733 args=(torch.ones(n, n), torch.ones(n, n)), 3734 ) 3735 # pass in graceful=True to ensure that local UserRRefs are deleted. 3736 rpc.shutdown(graceful=True) 3737 3738 with self.assertRaisesRegex( 3739 RuntimeError, "Cannot call to_here\\(\\) on it after deletion." 3740 ): 3741 rref.to_here() 3742 3743 with self.assertRaisesRegex( 3744 RuntimeError, "Cannot call fork an UserRRef after deletion." 
3745 ): 3746 import torch.distributed.rpc.internal as internal 3747 internal.serialize(rref) 3748 3749 @staticmethod 3750 def _return_gpu_tensor(): 3751 return torch.rand(3, 3).cuda(0) 3752 3753 @staticmethod 3754 def _return_gpu_tensor_list(): 3755 return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)] 3756 3757 @staticmethod 3758 def _gpu_tensor_list_arg(tensor_list): 3759 return torch.rand(3, 3) 3760 3761 def _create_rref(self): 3762 owner_rank = (self.rank + 2) % self.world_size 3763 return rpc.remote( 3764 worker_name(owner_rank), 3765 torch.add, 3766 args=(torch.zeros(2, 2), 1) 3767 ) 3768 3769 @dist_init 3770 def test_user_rrefs_confirmed(self): 3771 dst_rank = (self.rank + 1) % self.world_size 3772 rref = self._create_rref() 3773 ret = rpc.rpc_sync( 3774 worker_name(dst_rank), 3775 check_rref_confirmed, 3776 args=(rref,) 3777 ) 3778 self.assertEqual(ret, True) 3779 3780 @dist_init 3781 def test_user_rrefs_confirmed_remote(self): 3782 dst_rank = (self.rank + 1) % self.world_size 3783 rref = self._create_rref() 3784 ret_rref = rpc.remote( 3785 worker_name(dst_rank), 3786 check_rref_confirmed, 3787 args=(rref,) 3788 ) 3789 self.assertEqual(ret_rref.to_here(), True) 3790 3791 @dist_init 3792 def test_rref_py_pickle_not_supported(self): 3793 local_rref = RRef(35) 3794 with TemporaryFileName() as fname: 3795 with self.assertRaisesRegex(RuntimeError, "Can not pickle rref in python pickler"): 3796 torch.save(local_rref, fname) 3797 3798 @dist_init 3799 def test_remote_throw(self): 3800 rref = rpc.remote(worker_name((self.rank + 1) % self.world_size), 3801 raise_or_inc, 3802 args=(torch.ones(2),)) 3803 with self.assertRaisesRegex(Exception, ".*Expected error.*"): 3804 rref.to_here() 3805 3806 @dist_init 3807 def test_non_cont_tensors(self): 3808 if self.rank == 0: 3809 # Create a non-contiguous tensor. 3810 t = torch.rand(5, 5) 3811 t_view = t.narrow(1, 2, 2) 3812 self.assertFalse(t_view.is_contiguous()) 3813 t_cont = t_view.contiguous() 3814 self.assertTrue(t_cont.is_contiguous()) 3815 self.assertEqual(t_view, t_cont) 3816 3817 # Send non-cont tensor over RPC. 3818 next_rank = (self.rank + 1) % self.world_size 3819 t_ret = rpc.rpc_sync(worker_name(next_rank), non_cont_test, args=(t_view, t_cont)) 3820 3821 # Verify the returned tensor. 
3822 self.assertEqual(t_view, t_ret) 3823 self.assertFalse(t_ret.is_contiguous()) 3824 3825 @dist_init 3826 def test_callback_simple(self): 3827 set_by_cb = concurrent.futures.Future() 3828 n = self.rank + 1 3829 3830 def callback(fut): 3831 ret = fut.wait() 3832 self.assertEqual(ret, torch.ones(n, n) * 2) 3833 set_by_cb.set_result(ret.clone() + 1) 3834 3835 fut = rpc.rpc_async( 3836 worker_name(n % self.world_size), 3837 torch.add, 3838 args=(torch.ones(n, n), torch.ones(n, n)) 3839 ) 3840 3841 fut.then(callback) 3842 3843 self.assertEqual(fut.wait(), torch.ones(n, n) * 2) 3844 self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1) 3845 self.assertEqual(fut.wait(), torch.ones(n, n) * 2) 3846 3847 @dist_init 3848 def test_callback_wrong_arg_num(self): 3849 set_by_cb = concurrent.futures.Future() 3850 n = self.rank + 1 3851 3852 fut = rpc.rpc_async( 3853 worker_name(n % self.world_size), 3854 torch.add, 3855 args=(torch.ones(n, n), torch.ones(n, n)) 3856 ) 3857 3858 cb_fut = fut.then(my_function) 3859 3860 self.assertEqual(fut.wait(), torch.ones(n, n) * 2) 3861 3862 with self.assertRaisesRegex( 3863 RuntimeError, 3864 "my\\_function\\(\\) missing 2 required positional arguments" 3865 ): 3866 cb_fut.wait() 3867 3868 @dist_init 3869 def test_callback_wrong_arg_type(self): 3870 dst = worker_name((self.rank + 1) % self.world_size) 3871 3872 fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)) 3873 fut1 = fut0.then(lambda x: x + 1) 3874 3875 with self.assertRaisesRegex( 3876 RuntimeError, 3877 "unsupported operand type\\(s\\) for \\+" 3878 ): 3879 fut1.wait() 3880 3881 @dist_init 3882 def test_callback_multi(self): 3883 num_cbs = 10 3884 n = self.rank + 1 3885 3886 def callback(idx, fut): 3887 ret = fut.wait() 3888 self.assertEqual(ret, torch.ones(n, n) * 2) 3889 return ret + idx 3890 3891 fut = rpc.rpc_async( 3892 worker_name(n % self.world_size), 3893 torch.add, 3894 args=(torch.ones(n, n), torch.ones(n, n)) 3895 ) 3896 3897 cb_futs = [] 3898 for idx in range(num_cbs): 3899 cb_futs.append(fut.then(partial(callback, idx))) 3900 3901 self.assertEqual(fut.wait(), torch.ones(n, n) * 2) 3902 3903 for idx in range(num_cbs): 3904 self.assertEqual( 3905 cb_futs[idx].wait(), 3906 torch.ones(n, n) * 2 + idx 3907 ) 3908 3909 self.assertEqual(fut.wait(), torch.ones(n, n) * 2) 3910 3911 @dist_init 3912 def test_callback_chain(self): 3913 n = self.rank + 1 3914 dst = worker_name(n % self.world_size) 3915 3916 def callback(fut): 3917 return fut.wait() + 1 3918 3919 fut = rpc.rpc_async( 3920 worker_name(n % self.world_size), 3921 torch.add, 3922 args=(torch.ones(n, n), 1) 3923 ) 3924 3925 num_cbs = 20 3926 for _ in range(num_cbs): 3927 fut = fut.then(callback) 3928 3929 self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) 3930 3931 @dist_init 3932 def test_callback_in_rpc(self): 3933 dst1 = worker_name((self.rank + 1) % self.world_size) 3934 dst2 = worker_name((self.rank + 2) % self.world_size) 3935 3936 ret = rpc.rpc_sync( 3937 dst1, 3938 add_use_future_cb, 3939 args=(dst2, torch.ones(2, 2), 1, 2) 3940 ) 3941 self.assertEqual(ret, torch.ones(2, 2) + 1 + 2) 3942 3943 @dist_init 3944 def test_callback_with_ret(self): 3945 dst = worker_name((self.rank + 1) % self.world_size) 3946 3947 def callback(fut0): 3948 fut2 = rpc.rpc_async( 3949 dst, 3950 torch.add, 3951 args=(fut0.wait(), 1) 3952 ).then(lambda fut1: fut1.wait() + 1) 3953 3954 return fut2.wait() 3955 3956 fut3 = rpc.rpc_async( 3957 dst, 3958 torch.add, 3959 args=(torch.ones(2, 2), 1) 3960 ).then(callback) 3961 3962 
self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3) 3963 3964 @dist_init 3965 def test_callback_with_error(self): 3966 dst = worker_name((self.rank + 1) % self.world_size) 3967 3968 def callback(fut0): 3969 with self.assertRaisesRegex(ValueError, "Expected error"): 3970 fut0.wait() 3971 raise RuntimeError("Another expected error") 3972 3973 fut1 = rpc.rpc_async(dst, raise_func).then(callback) 3974 with self.assertRaisesRegex(RuntimeError, "Another expected error"): 3975 fut1.wait() 3976 3977 @dist_init 3978 def test_callback_none(self): 3979 dst = worker_name((self.rank + 1) % self.world_size) 3980 with self.assertRaisesRegex( 3981 TypeError, 3982 "incompatible function arguments." 3983 ): 3984 rpc.rpc_async(dst, raise_func).then(None) 3985 3986 @dist_init 3987 def test_add_done_callback(self): 3988 set_by_cb = False 3989 n = self.rank + 1 3990 3991 def callback(fut): 3992 nonlocal set_by_cb 3993 fut.wait() 3994 set_by_cb = True 3995 3996 fut = rpc.rpc_async( 3997 worker_name(n % self.world_size), 3998 torch.add, 3999 args=(torch.ones(n, n), torch.ones(n, n)) 4000 ) 4001 4002 fut.add_done_callback(callback) 4003 fut_then = fut.then(lambda _: True) 4004 4005 self.assertEqual(fut.wait(), torch.ones(n, n) * 2) 4006 4007 # We have no guarantee that the add_done_callback fn will execute before the test finishes. 4008 # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback 4009 fut_then.wait() 4010 self.assertTrue(set_by_cb) 4011 self.assertEqual(fut.wait(), torch.ones(n, n) * 2) 4012 4013 @dist_init 4014 def test_mark_future_twice(self): 4015 fut = rpc.rpc_async( 4016 worker_name((self.rank + 1) % self.world_size), 4017 torch.add, 4018 args=(torch.zeros(2, 2), 1) 4019 ) 4020 self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1) 4021 with self.assertRaisesRegex( 4022 RuntimeError, 4023 "Future can only be marked completed once" 4024 ): 4025 fut.set_result(1) 4026 4027 @dist_init 4028 def test_pickle_future(self): 4029 fut = torch.futures.Future() 4030 errMsg = "Can not pickle torch.futures.Future" 4031 4032 dst = worker_name((self.rank + 1) % self.world_size) 4033 with TemporaryFileName() as fname: 4034 with self.assertRaisesRegex(RuntimeError, errMsg): 4035 rpc.rpc_sync(dst, fail_on_fut, args=(fut,)) 4036 4037 with TemporaryFileName() as fname: 4038 with self.assertRaisesRegex(RuntimeError, errMsg): 4039 rpc.rpc_async(dst, fail_on_fut, args=(fut,)) 4040 4041 with TemporaryFileName() as fname: 4042 with self.assertRaisesRegex(RuntimeError, errMsg): 4043 rpc.remote(dst, fail_on_fut, args=(fut,)) 4044 4045 @dist_init 4046 def test_future_done(self): 4047 dst = worker_name((self.rank + 1) % self.world_size) 4048 fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1)) 4049 fut.wait() 4050 self.assertTrue(fut.done()) 4051 4052 @dist_init 4053 def test_future_done_exception(self): 4054 dst = worker_name((self.rank + 1) % self.world_size) 4055 fut = rpc.rpc_async(dst, raise_func) 4056 with self.assertRaisesRegex(ValueError, "Expected error"): 4057 fut.wait() 4058 self.assertTrue(fut.done()) 4059 4060 def _test_future_cb(self, func): 4061 dst1 = worker_name((self.rank + 1) % self.world_size) 4062 dst2 = worker_name((self.rank + 2) % self.world_size) 4063 4064 ret = rpc.rpc_sync( 4065 dst1, 4066 func, 4067 args=(dst2, torch.ones(2, 2), 1, 2) 4068 ) 4069 self.assertEqual(ret, torch.ones(2, 2) + 1 + 2) 4070 4071 @dist_init 4072 def test_future_in_rpc(self): 4073 self._test_future_cb(add_use_future_set_result) 4074 4075 @dist_init 4076 def 
test_future_nested_callback(self): 4077 self._test_future_cb(add_use_future_nested_cb) 4078 4079 def _test_async_function_raise(self, mode): 4080 with self.assertRaisesRegex(RuntimeError, "Expected error"): 4081 self._run_func_in_mode( 4082 worker_name((self.rank + 1) % self.world_size), 4083 async_raise_func, 4084 mode 4085 ) 4086 4087 @dist_init 4088 def test_async_function_raise(self): 4089 self._test_async_function_raise(RPCExecMode.SYNC) 4090 4091 @dist_init 4092 def test_async_function_raise_async(self): 4093 self._test_async_function_raise(RPCExecMode.ASYNC) 4094 4095 @dist_init 4096 def test_async_function_raise_remote(self): 4097 self._test_async_function_raise(RPCExecMode.REMOTE) 4098 4099 def _test_async_function_wrong_return_type(self, mode): 4100 errMsg = ( 4101 "Functions decorated with @rpc\\.async_function must return a " 4102 "torch\\.futures\\.Future object," 4103 ) 4104 with self.assertRaisesRegex(RuntimeError, errMsg): 4105 self._run_func_in_mode( 4106 worker_name((self.rank + 1) % self.world_size), 4107 async_wrong_type, 4108 mode 4109 ) 4110 4111 @dist_init 4112 def test_async_function_wrong_return_type(self): 4113 self._test_async_function_wrong_return_type(RPCExecMode.SYNC) 4114 4115 @dist_init 4116 def test_async_function_wrong_return_type_async(self): 4117 self._test_async_function_wrong_return_type(RPCExecMode.ASYNC) 4118 4119 @dist_init 4120 def test_async_function_wrong_return_type_remote(self): 4121 self._test_async_function_wrong_return_type(RPCExecMode.REMOTE) 4122 4123 @dist_init 4124 def test_async_function_simple(self): 4125 dst1 = worker_name((self.rank + 1) % self.world_size) 4126 dst2 = worker_name((self.rank + 2) % self.world_size) 4127 4128 ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1)) 4129 self.assertEqual(ret, torch.ones(2, 2) + 1) 4130 4131 def _test_async_function(self, fn, mode=RPCExecMode.SYNC): 4132 dst1 = worker_name((self.rank + 1) % self.world_size) 4133 dst2 = worker_name((self.rank + 2) % self.world_size) 4134 4135 args = (dst2, torch.ones(2, 2), 1, 2) 4136 ret = self._run_func_in_mode(dst1, fn, mode, args=args) 4137 self.assertEqual(ret, torch.ones(2, 2) + 3) 4138 4139 @dist_init 4140 def test_async_function_with_future_ctor(self): 4141 self._test_async_function(async_add_with_future_ctor) 4142 4143 @dist_init 4144 def test_async_function_with_future_ctor_remote(self): 4145 self._test_async_function( 4146 async_add_with_future_ctor, 4147 RPCExecMode.REMOTE 4148 ) 4149 4150 @dist_init 4151 def test_async_function_chained(self): 4152 self._test_async_function(async_add_chained) 4153 4154 @dist_init 4155 def test_async_function_chained_remote(self): 4156 self._test_async_function(async_add_chained, RPCExecMode.REMOTE) 4157 4158 @dist_init 4159 def test_async_function_nested(self): 4160 self._test_async_function(async_add_nested) 4161 4162 @dist_init 4163 def test_async_function_nested_remote(self): 4164 self._test_async_function(async_add_nested, RPCExecMode.REMOTE) 4165 4166 @dist_init 4167 def test_async_static_method(self): 4168 self._test_async_function(AsyncExecutionClass.static_async_add) 4169 4170 @dist_init 4171 def test_async_static_method_remote(self): 4172 self._test_async_function( 4173 AsyncExecutionClass.static_async_add, 4174 RPCExecMode.REMOTE 4175 ) 4176 4177 @dist_init 4178 def test_async_class_method(self): 4179 self._test_async_function(AsyncExecutionClass.class_async_add) 4180 4181 @dist_init 4182 def test_async_class_method_remote(self): 4183 self._test_async_function( 4184 
AsyncExecutionClass.class_async_add, 4185 RPCExecMode.REMOTE 4186 ) 4187 4188 def _test_test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC): 4189 dst1 = worker_name((self.rank + 1) % self.world_size) 4190 dst2 = worker_name((self.rank + 2) % self.world_size) 4191 rref = rpc.remote(dst1, AsyncExecutionClass) 4192 4193 x = torch.ones(2, 2) 4194 y = torch.ones(2, 2) + 1 4195 if mode == RPCExecMode.SYNC: 4196 ret = rref.rpc_sync().static_async_add(dst2, x, x, y) 4197 ret += rref.rpc_sync().class_async_add(dst2, x, x, y) 4198 ret += rref.rpc_sync().bound_async_add(dst2, x, x, y) 4199 elif mode == RPCExecMode.ASYNC: 4200 ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait() 4201 ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait() 4202 ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait() 4203 elif mode == RPCExecMode.REMOTE: 4204 ret = rref.remote().static_async_add(dst2, x, x, y).to_here() 4205 ret += rref.remote().class_async_add(dst2, x, x, y).to_here() 4206 ret += rref.remote().bound_async_add(dst2, x, x, y).to_here() 4207 4208 self.assertEqual(ret, 3 * 4 * x) 4209 4210 @dist_init 4211 def test_async_class_rref_proxy(self): 4212 self._test_test_async_class_rref_proxy() 4213 4214 @dist_init 4215 def test_async_class_rref_proxy_async(self): 4216 self._test_test_async_class_rref_proxy(mode=RPCExecMode.ASYNC) 4217 4218 @dist_init 4219 def test_async_class_rref_proxy_remote(self): 4220 self._test_test_async_class_rref_proxy(mode=RPCExecMode.REMOTE) 4221 4222 def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC): 4223 dst1 = worker_name((self.rank + 1) % self.world_size) 4224 dst2 = worker_name((self.rank + 2) % self.world_size) 4225 4226 num = 20 4227 step = 3 4228 args = (dst2, torch.ones(2, 2), num, step) 4229 ret = self._run_func_in_mode(dst1, fn, mode, args=args) 4230 self.assertEqual(ret, torch.ones(2, 2) + num * step) 4231 4232 @dist_init 4233 def test_async_function_multi_chained(self): 4234 self._test_async_function_multi(async_add_chained_multi) 4235 4236 @dist_init 4237 def test_async_function_multi_chained_async(self): 4238 self._test_async_function_multi( 4239 async_add_chained_multi, 4240 RPCExecMode.ASYNC 4241 ) 4242 4243 @dist_init 4244 def test_async_function_multi_chained_remote(self): 4245 self._test_async_function_multi( 4246 async_add_chained_multi, 4247 RPCExecMode.REMOTE 4248 ) 4249 4250 @dist_init 4251 def test_async_function_multi_fanout(self): 4252 self._test_async_function_multi(async_add_multi_fanout) 4253 4254 @dist_init 4255 def test_async_function_multi_fanout_async(self): 4256 self._test_async_function_multi( 4257 async_add_multi_fanout, 4258 RPCExecMode.ASYNC 4259 ) 4260 4261 @dist_init 4262 def test_async_function_multi_fanout_remote(self): 4263 self._test_async_function_multi( 4264 async_add_multi_fanout, 4265 RPCExecMode.REMOTE 4266 ) 4267 4268 def _test_return_future(self, mode): 4269 with self.assertRaisesRegex( 4270 RuntimeError, 4271 "Can not pickle torch.futures.Future" 4272 ): 4273 self._run_func_in_mode( 4274 worker_name((self.rank + 1) % self.world_size), 4275 return_future, 4276 mode 4277 ) 4278 4279 @dist_init 4280 def test_return_future(self): 4281 self._test_return_future(RPCExecMode.SYNC) 4282 4283 @dist_init 4284 def test_return_future_async(self): 4285 self._test_return_future(RPCExecMode.ASYNC) 4286 4287 @dist_init 4288 def test_return_future_remote(self): 4289 self._test_return_future(RPCExecMode.REMOTE) 4290 4291 @dist_init 4292 def test_rref_timeout(self): 4293 # This test is similar to ones in 
FaultyProcessGroupTest, but is meant to be 4294 # run with other backends besides ProcessGroup. 4295 if self.rank != 0: 4296 return 4297 4298 dst_rank = (self.rank + 1) % self.world_size 4299 dst_worker = f"worker{dst_rank}" 4300 # 10 ms timeout 4301 rref = rpc.remote(dst_worker, my_sleep_func, args=(2, ), timeout=0.01) 4302 # Future corresponding to the remote creation should time out. 4303 expected_error = self.get_timeout_error_regex() 4304 with self.assertRaisesRegex(RuntimeError, expected_error): 4305 rref._get_future().wait() 4306 # Call to ensure pending callbacks are run. 4307 wait_until_pending_futures_and_users_flushed() 4308 with self.assertRaisesRegex(RuntimeError, "RRef creation"): 4309 rref.to_here() 4310 4311 wait_until_owners_and_forks_on_rank(1, 1, rank=1) 4312 4313 @dist_init(setup_rpc=False) 4314 @skip_but_pass_in_sandcastle_if( 4315 os.environ.get("RPC_INIT_WITH_TCP", None) == "1", 4316 "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614." 4317 ) 4318 def test_init_pg_then_rpc(self): 4319 dist.init_process_group( 4320 backend="gloo", 4321 init_method=self.init_method, 4322 rank=self.rank, 4323 world_size=self.world_size, 4324 ) 4325 4326 rpc.init_rpc( 4327 name=worker_name(self.rank), 4328 backend=self.rpc_backend, 4329 rank=self.rank, 4330 world_size=self.world_size, 4331 rpc_backend_options=self.rpc_backend_options, 4332 ) 4333 4334 # Test RPC. 4335 next_rank = (self.rank + 1) % self.world_size 4336 ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)) 4337 self.assertEqual(ret, torch.ones(2, 2) + 1) 4338 4339 # Test PG 4340 dist.barrier() 4341 4342 rpc.shutdown() 4343 4344 @dist_init(setup_rpc=False) 4345 @skip_but_pass_in_sandcastle_if( 4346 os.environ.get("RPC_INIT_WITH_TCP", None) == "1", 4347 "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614." 4348 ) 4349 def test_init_rpc_then_pg(self): 4350 rpc.init_rpc( 4351 name=worker_name(self.rank), 4352 backend=self.rpc_backend, 4353 rank=self.rank, 4354 world_size=self.world_size, 4355 rpc_backend_options=self.rpc_backend_options, 4356 ) 4357 4358 dist.init_process_group( 4359 backend="gloo", 4360 init_method=self.init_method, 4361 rank=self.rank, 4362 world_size=self.world_size, 4363 ) 4364 4365 # Test RPC. 
4366 next_rank = (self.rank + 1) % self.world_size 4367 ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)) 4368 self.assertEqual(ret, torch.ones(2, 2) + 1) 4369 4370 # Test PG 4371 dist.barrier() 4372 4373 rpc.shutdown() 4374 4375 @dist_init 4376 def test_wait_all_with_exception(self): 4377 futs = [] 4378 dst = worker_name((self.rank + 1) % self.world_size) 4379 for _ in range(10): 4380 futs.append(rpc.rpc_async(dst, raise_func)) 4381 4382 with self.assertRaisesRegex(ValueError, "Expected error"): 4383 ret = torch.futures.wait_all(futs) 4384 4385 @dist_init 4386 def test_wait_all_with_partial_exception(self): 4387 futs = [] 4388 dst = worker_name((self.rank + 1) % self.world_size) 4389 for _ in range(10): 4390 futs.append(rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1))) 4391 4392 futs.append(rpc.rpc_async(dst, raise_func)) 4393 4394 with self.assertRaisesRegex(ValueError, "Expected error"): 4395 ret = torch.futures.wait_all(futs) 4396 4397 @dist_init(setup_rpc=False) 4398 @skip_but_pass_in_sandcastle_if( 4399 os.environ.get("RPC_INIT_WITH_TCP", None) == "1", 4400 "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491", 4401 ) 4402 def test_init_rpc_twice(self): 4403 initialize_pg(self.file_init_method, self.rank, self.world_size) 4404 4405 rpc.init_rpc( 4406 name=worker_name(self.rank), 4407 backend=self.rpc_backend, 4408 rank=self.rank, 4409 world_size=self.world_size, 4410 rpc_backend_options=self.rpc_backend_options, 4411 ) 4412 rpc.shutdown() 4413 4414 # Wait for all init to complete. 4415 dist.barrier() 4416 4417 # Use a different file name for the next initialization 4418 new_backend_options = self.rpc_backend_options 4419 new_backend_options.init_method += "init_2" 4420 4421 # Ensure rpc initialization works again. 4422 rpc.init_rpc( 4423 name=worker_name(self.rank), 4424 backend=self.rpc_backend, 4425 rank=self.rank, 4426 world_size=self.world_size, 4427 rpc_backend_options=new_backend_options, 4428 ) 4429 4430 # Verify RPCs work after re-init. 4431 dst = worker_name((self.rank + 1) % self.world_size) 4432 rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) 4433 rpc.rpc_sync(dst, foo_add, args=()) 4434 4435 rpc.shutdown() 4436 4437 def test_wrong_types(self): 4438 with self.assertRaisesRegex( 4439 TypeError, 4440 "Argument backend must be a member of BackendType", 4441 ): 4442 rpc.init_rpc( 4443 name=worker_name(self.rank), 4444 rank=self.rank, 4445 world_size=self.world_size, 4446 backend="TENSORPIPE", 4447 ) 4448 4449 with self.assertRaisesRegex( 4450 TypeError, 4451 "Argument rpc_backend_options must be an instance of RpcBackendOptions", 4452 ): 4453 rpc.init_rpc( 4454 name=worker_name(self.rank), 4455 rank=self.rank, 4456 world_size=self.world_size, 4457 backend=self.rpc_backend, 4458 rpc_backend_options={"init_method": self.init_method} 4459 ) 4460 4461 def test_cannot_infer_backend_from_options(self): 4462 # An exception should be raised if the backend isn't specified but 4463 # options are given which are not an instance of any of the known 4464 # agents' option classes. 4465 rpc_backend_options = FooBackendOptions(self.init_method) 4466 4467 with self.assertRaisesRegex(TypeError, "Could not infer backend for options"): 4468 rpc.init_rpc( 4469 name=worker_name(self.rank), 4470 rank=self.rank, 4471 world_size=self.world_size, 4472 # Do _not_ pass backend. 
4473 rpc_backend_options=rpc_backend_options, 4474 ) 4475 4476 @dist_init 4477 def test_owner_rref_backward(self): 4478 dst = worker_name((self.rank + 1) % self.world_size) 4479 t1 = torch.rand(10, 10, requires_grad=True) 4480 rref = rpc.RRef(t1.sum() + t1.sum()) 4481 rref.backward() 4482 expected_grad = torch.ones_like(t1) * 2 4483 self.assertEqual(expected_grad, t1.grad) 4484 4485 with dist_autograd.context() as context_id: 4486 t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) 4487 rref = rpc.RRef(t2.sum()) 4488 rref.backward(context_id) 4489 self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1]) 4490 4491 # Double backward. 4492 with dist_autograd.context() as context_id: 4493 t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) 4494 rref = rpc.RRef(t2.sum()) 4495 rref.backward(context_id, retain_graph=True) 4496 rref.backward(context_id) 4497 self.assertEqual(expected_grad * 2, dist_autograd.get_gradients(context_id)[t1]) 4498 4499 # Test errors. 4500 with self.assertRaisesRegex(RuntimeError, "tensors does not require grad and does not have a grad_fn"): 4501 rpc.RRef(torch.rand(10)).backward() 4502 4503 with self.assertRaisesRegex(RuntimeError, "grad can be implicitly created only for scalar outputs"): 4504 rpc.RRef(torch.rand(10, requires_grad=True)).backward() 4505 4506 with self.assertRaisesRegex(RuntimeError, "Could not find autograd context with id: 100"): 4507 rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100) 4508 4509 with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"): 4510 rpc.RRef("foo").backward() 4511 4512 @staticmethod 4513 def _sum(x): 4514 return x.sum() 4515 4516 @staticmethod 4517 def _identity(x): 4518 return x 4519 4520 @dist_init 4521 def test_user_rref_backward(self): 4522 dst = worker_name((self.rank + 1) % self.world_size) 4523 t = torch.rand(10, requires_grad=True) 4524 with dist_autograd.context() as context_id: 4525 rref = rpc.remote(dst, RpcTest._sum, args=(t,)) 4526 rref.backward(context_id, retain_graph=True) 4527 rref.backward(context_id) 4528 self.assertEqual(torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t]) 4529 4530 with dist_autograd.context() as context_id: 4531 rref = rpc.remote(dst, RpcTest._identity, args=("foo",)) 4532 with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"): 4533 rref.backward(context_id) 4534 4535 with self.assertRaisesRegex(RuntimeError, "User RRefs require 'dist_autograd_ctx_id' to be specified"): 4536 rref.backward() 4537 4538 @dist_init(setup_rpc=False) 4539 def test_shutdown_errors(self): 4540 initialize_pg(self.file_init_method, self.rank, self.world_size) 4541 4542 rpc.init_rpc( 4543 name=worker_name(self.rank), 4544 backend=self.rpc_backend, 4545 rank=self.rank, 4546 world_size=self.world_size, 4547 rpc_backend_options=self.rpc_backend_options, 4548 ) 4549 4550 if self.rank != 0: 4551 og_func = rpc.api._broadcast_to_followers 4552 og_rref_func = rpc.api._delete_all_user_and_unforked_owner_rrefs 4553 4554 # Monkey-patch _broadcast_to_followers to fail, which would ensure 4555 # _all_gather on leader raises an exception. 4556 def raise_error(sequence_id, objects_map): 4557 og_func(sequence_id, objects_map) 4558 raise RuntimeError('simulation') 4559 4560 # Monkey-patch _delete_all_user_and_unforked_owner_rrefs to fail, 4561 # which would ensure barrier is not called on followers. 
4562 def rref_error(): 4563 raise RuntimeError('simulation rref') 4564 4565 try: 4566 rpc.api._broadcast_to_followers = raise_error 4567 rpc.api._delete_all_user_and_unforked_owner_rrefs = rref_error 4568 with self.assertRaisesRegex(RuntimeError, 'simulation rref'): 4569 rpc.shutdown() 4570 finally: 4571 rpc.api._broadcast_to_followers = og_func 4572 rpc.api._delete_all_user_and_unforked_owner_rrefs = og_rref_func 4573 else: 4574 with self.assertRaisesRegex(RuntimeError, 'timed out in _all_gather'): 4575 rpc.shutdown() 4576 4577 dist.barrier() 4578 4579 @dist_init 4580 def test_my_parameter_server(self): 4581 self._my_parameter_server(False) 4582 4583 4584class CudaRpcTest(RpcAgentTestFixture): 4585 4586 @skip_if_lt_x_gpu(2) 4587 @dist_init 4588 def test_profiler_remote_cuda(self): 4589 if self.rank != 1: 4590 return 4591 4592 dst_cuda_0 = (self.rank + 1) % self.world_size 4593 dst_cuda_1 = (self.rank + 2) % self.world_size 4594 dst_worker_cuda_0 = worker_name(dst_cuda_0) 4595 dst_worker_cuda_1 = worker_name(dst_cuda_1) 4596 4597 with _profile(use_cuda=True) as p: 4598 fut1 = rpc.rpc_async(dst_worker_cuda_0, udf_with_torch_ops, args=(0, )) 4599 fut2 = rpc.rpc_async(dst_worker_cuda_1, udf_with_torch_ops, args=(1, )) 4600 fut1.wait() 4601 fut2.wait() 4602 4603 def get_name(event): 4604 return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR):] 4605 4606 function_events = p.function_events 4607 for event in function_events: 4608 if event.is_async: 4609 self.assertEqual(0, event.device_time_total) 4610 self.assertEqual([], event.kernels) 4611 self.assertEqual(0, event.device_time) 4612 else: 4613 if event.node_id == 1: 4614 continue 4615 self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1]) 4616 if get_name(event) in EXPECTED_REMOTE_EVENTS: 4617 self.assertGreater(event.device_time_total, 0) 4618 self.assertEqual(1, len(event.kernels)) 4619 kernel = event.kernels[0] 4620 if event.node_id == dst_cuda_0: 4621 self.assertEqual(kernel.device, 0) 4622 if event.node_id == dst_cuda_1: 4623 self.assertEqual(kernel.device, 1) 4624 self.assertGreater(event.device_time, 0) 4625 4626 # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled 4627 # events. 4628 remote_events = [event for event in function_events if event.is_remote] 4629 remote_event_names = [get_name(event) for event in remote_events if get_name(event) in EXPECTED_REMOTE_EVENTS] 4630 self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS)) 4631 4632 4633class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon): 4634 4635 def test_mismatched_type_for_options(self): 4636 # An exception should be raised if the options are not an instance of 4637 # TensorPipeRpcBackendOptions. 4638 rpc_backend_options = FooBackendOptions(self.init_method) 4639 4640 with self.assertRaisesRegex( 4641 TypeError, "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`" 4642 ): 4643 rpc.init_rpc( 4644 name=worker_name(self.rank), 4645 rank=self.rank, 4646 world_size=self.world_size, 4647 backend=rpc.BackendType.TENSORPIPE, 4648 rpc_backend_options=rpc_backend_options, 4649 ) 4650 4651 def test_infer_backend_from_options(self): 4652 rpc_backend_options = rpc.TensorPipeRpcBackendOptions( 4653 init_method=self.init_method, 4654 _transports=tp_transports() 4655 ) 4656 4657 rpc.init_rpc( 4658 name=worker_name(self.rank), 4659 rank=self.rank, 4660 world_size=self.world_size, 4661 # Do _not_ pass backend. 
4662 rpc_backend_options=rpc_backend_options, 4663 ) 4664 4665 self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent) 4666 4667 # FIXME Merge this test with the corresponding one in RpcTest. 4668 @dist_init(setup_rpc=False) 4669 def test_set_and_get_num_worker_threads(self): 4670 NUM_THREADS = 27 4671 rpc_backend_options = rpc.TensorPipeRpcBackendOptions( 4672 init_method=self.rpc_backend_options.init_method, 4673 num_worker_threads=NUM_THREADS, 4674 _transports=tp_transports(), 4675 ) 4676 rpc.init_rpc( 4677 name=worker_name(self.rank), 4678 backend=self.rpc_backend, 4679 rank=self.rank, 4680 world_size=self.world_size, 4681 rpc_backend_options=rpc_backend_options, 4682 ) 4683 4684 info = rpc.api._get_current_rpc_agent().get_debug_info() 4685 self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS) 4686 rpc.shutdown() 4687 4688 # FIXME Merge this test with the corresponding one in RpcTest. 4689 @dist_init(setup_rpc=False) 4690 def test_tensorpipe_set_default_timeout(self): 4691 # Set a high timeout since it doesn't affect test runtime and ensures 4692 # the test doesn't erroneously timeout due to slow machines. 4693 timeout = 100 4694 rpc_backend_options = rpc.TensorPipeRpcBackendOptions( 4695 init_method=self.rpc_backend_options.init_method, 4696 num_worker_threads=self.rpc_backend_options.num_worker_threads, 4697 rpc_timeout=timeout, 4698 _transports=tp_transports(), 4699 ) 4700 rpc.init_rpc( 4701 name=worker_name(self.rank), 4702 backend=self.rpc_backend, 4703 rank=self.rank, 4704 world_size=self.world_size, 4705 rpc_backend_options=rpc_backend_options, 4706 ) 4707 4708 default_timeout = rpc.get_rpc_timeout() 4709 self.assertEqual(default_timeout, timeout) 4710 rpc.shutdown() 4711 4712 # FIXME Merge this test with the corresponding one in RpcTest. 4713 @dist_init(setup_rpc=False) 4714 def test_tensorpipe_options_throw_on_timedelta_timeout(self): 4715 from datetime import timedelta 4716 4717 timeout = timedelta() 4718 # Ensure that constructing TensorPipeRpcBackendOptions with timedelta fails 4719 with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"): 4720 rpc_backend_options = rpc.TensorPipeRpcBackendOptions( 4721 init_method=self.rpc_backend_options.init_method, 4722 num_worker_threads=self.rpc_backend_options.num_worker_threads, 4723 rpc_timeout=timeout, 4724 ) 4725 4726 @dist_init 4727 def _test_rref_get_type_timeout(self, blocking): 4728 # Test where we try to get the type of a RRef from an owner, but RRef 4729 # creation is slower than timeout passed into _get_type. 4730 dst_rank = (self.rank + 1) % self.world_size 4731 dst = worker_name(dst_rank) 4732 slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True)) 4733 timeout = 0.5 4734 expected_err = self.get_timeout_error_regex() 4735 # Blocking: blocks on inline call 4736 if blocking: 4737 with self.assertRaisesRegex(RuntimeError, expected_err): 4738 slow_rref._get_type(timeout=timeout, blocking=blocking) 4739 # Non-blocking: blocks on wait 4740 else: 4741 fut = slow_rref._get_type(timeout=timeout, blocking=blocking) 4742 with self.assertRaisesRegex(RuntimeError, expected_err): 4743 fut.wait() 4744 4745 # FIXME We wait until the remote completed creating the OwnerRRef 4746 # because there's currently a race if we shut down RPC before that. 
4747 slow_rref.to_here() 4748 4749 def test_rref_get_type_timeout_blocking(self): 4750 self._test_rref_get_type_timeout(blocking=True) 4751 4752 def test_rref_get_type_timeout_non_blocking(self): 4753 self._test_rref_get_type_timeout(blocking=False) 4754 4755 @dist_init 4756 def test_op_with_invalid_args(self): 4757 dst = worker_name((self.rank + 1) % self.world_size) 4758 with self.assertRaisesRegex( 4759 RuntimeError, "Overloaded torch operator invoked from Python failed to match any schema" 4760 ): 4761 rpc.rpc_sync(dst, torch.add, args=()) 4762 4763 def _test_rref_proxy_timeout(self, rref_proxy_api): 4764 dst_rank = (self.rank + 1) % self.world_size 4765 dst = worker_name(dst_rank) 4766 rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), )) 4767 # Ensure RRef is created on remote node. 4768 rref.to_here() 4769 rref_api = getattr(rref, rref_proxy_api) 4770 self.assertTrue(rref_api is not None, f"Failed to get RRef proxy api: {rref_proxy_api}") 4771 expected_error = self.get_timeout_error_regex() 4772 timeout = 2 4773 with self.assertRaisesRegex(RuntimeError, expected_error): 4774 result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2)) 4775 if rref_api == rref.rpc_async: 4776 result.wait() 4777 elif rref_api == rref.remote: 4778 result._get_future().wait() 4779 4780 # Case where rpc.remote() is stuck and exceeds timeout 4781 slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True)) 4782 timeout = 0.01 4783 rref_api = getattr(slow_rref, rref_proxy_api) 4784 # Note that even when we call rref.rpc_async() in this case, we 4785 # time out in future creation, not waiting for future. This is because 4786 # rref proxy function calls rref._get_type before returning future, 4787 # which blocks on the RRef being created on owner node, until the 4788 # specified timeout. 4789 with self.assertRaisesRegex(RuntimeError, expected_error): 4790 result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2)) 4791 # rpc_async returns immediately and surface a timeout through wait() 4792 if rref_api == slow_rref.rpc_async: 4793 result.wait() 4794 4795 # FIXME We wait until the remote completed creating the OwnerRRef 4796 # because there's currently a race if we shut down RPC before that. 
4797 slow_rref.to_here() 4798 4799 @dist_init 4800 def test_rref_proxy_timeout(self): 4801 for rpc_api in ["rpc_sync", "rpc_async", "remote"]: 4802 self._test_rref_proxy_timeout(rpc_api) 4803 4804 @dist_init 4805 def test_send_to_rank_sparse(self): 4806 dst_rank = (self.rank + 1) % self.world_size 4807 4808 # Test sparse tensor 4809 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: 4810 x = build_sparse_tensor() 4811 y = build_sparse_tensor() 4812 expected_tensor = (x + y) 4813 ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) 4814 self.assertEqual(expected_tensor, ret) 4815 4816 for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: 4817 x = build_sparse_tensor(coalesce=True) 4818 y = build_sparse_tensor(coalesce=True) 4819 expected_tensor = (x + y) 4820 ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) 4821 self.assertEqual(expected_tensor, ret) 4822 4823 @dist_init 4824 def test_self_py_udf_remote_sparse(self): 4825 self._self_py_udf_remote( 4826 rpc.get_worker_info(), 4827 build_sparse_tensor(), 4828 build_sparse_tensor(), 4829 build_sparse_tensor() 4830 ) 4831 4832 @dist_init 4833 def test_self_remote_rref_as_rpc_arg_sparse(self): 4834 dst = worker_name((self.rank + 1) % self.world_size) 4835 self._self_remote_rref_as_rpc_arg( 4836 dst, 4837 build_sparse_tensor(), 4838 build_sparse_tensor(), 4839 build_sparse_tensor() 4840 ) 4841 4842 @dist_init 4843 def test_self_remote_rref_as_self_rpc_arg_sparse(self): 4844 self._self_remote_rref_as_rpc_arg( 4845 rpc.get_worker_info(), 4846 build_sparse_tensor(), 4847 build_sparse_tensor(), 4848 build_sparse_tensor() 4849 ) 4850 4851 @dist_init 4852 def test_self_remote_rref_as_remote_arg_sparse(self): 4853 dst = worker_name((self.rank + 1) % self.world_size) 4854 self._self_remote_rref_as_remote_arg( 4855 dst, 4856 build_sparse_tensor(), 4857 build_sparse_tensor(), 4858 build_sparse_tensor() 4859 ) 4860 4861 @dist_init 4862 def test_self_remote_rref_as_self_remote_arg_sparse(self): 4863 self._self_remote_rref_as_remote_arg( 4864 rpc.get_worker_info(), 4865 build_sparse_tensor(), 4866 build_sparse_tensor(), 4867 build_sparse_tensor() 4868 ) 4869 4870 def test_world_size_one_sparse(self): 4871 self._world_size_one( 4872 build_sparse_tensor(), 4873 build_sparse_tensor() 4874 ) 4875 4876 @dist_init 4877 def test_multi_rpc_sparse(self): 4878 self._multi_rpc(True) 4879 4880 def test_wait_all_workers_sparse(self): 4881 self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor()) 4882 4883 def test_wait_all_workers_twice_sparse(self): 4884 self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor()) 4885 4886 @dist_init 4887 def test_py_sparse_tensors_in_container(self): 4888 n = self.rank + 1 4889 dst_rank = n % self.world_size 4890 a = [build_sparse_tensor(), build_sparse_tensor()] 4891 ret = rpc.rpc_sync( 4892 worker_name(dst_rank), my_container_sum, args=(a,) 4893 ) 4894 self.assertEqual(ret, my_container_sum(a)) 4895 4896 @dist_init 4897 def test_nested_rpc_sparse(self): 4898 self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2) 4899 4900 @dist_init 4901 def test_stress_heavy_rpc_sparse(self): 4902 self._stress_test_rpc(heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),)) 4903 4904 @dist_init 4905 def test_builtin_remote_ret_sparse(self): 4906 self._builtin_remote_ret( 4907 build_sparse_tensor(), 4908 build_sparse_tensor(), 4909 build_sparse_tensor() * 2 4910 ) 4911 4912 @dist_init 4913 def test_builtin_remote_self_sparse(self): 
4914 self._builtin_remote_self( 4915 build_sparse_tensor(), 4916 build_sparse_tensor(), 4917 build_sparse_tensor() * 2 4918 ) 4919 4920 @dist_init 4921 def test_multi_builtin_remote_ret_sparse(self): 4922 self._test_multi_remote_call( 4923 torch.add, True, 4924 args_fn=RpcTest._multi_args_fn 4925 ) 4926 4927 @dist_init 4928 def test_multi_py_udf_remote_sparse(self): 4929 self._test_multi_remote_call( 4930 my_function, 4931 True, 4932 kwargs_fn=RpcTest._multi_kwargs_fn 4933 ) 4934 4935 @dist_init 4936 def test_py_rref_args_sparse(self): 4937 self._py_rref_args( 4938 build_sparse_tensor(), 4939 build_sparse_tensor(), 4940 build_sparse_tensor(), 4941 build_sparse_tensor(), 4942 build_sparse_tensor() * 4 4943 ) 4944 4945 @dist_init 4946 def test_py_rref_args_user_share_sparse(self): 4947 self._py_rref_args_user_share( 4948 build_sparse_tensor(), 4949 build_sparse_tensor(), 4950 build_sparse_tensor(), 4951 build_sparse_tensor(), 4952 build_sparse_tensor(), 4953 build_sparse_tensor(), 4954 build_sparse_tensor() * 6 4955 ) 4956 4957 @dist_init 4958 def test_py_rpc_rref_args_sparse(self): 4959 self._py_rpc_rref_args( 4960 build_sparse_tensor(), 4961 build_sparse_tensor(), 4962 build_sparse_tensor(), 4963 build_sparse_tensor(), 4964 build_sparse_tensor(), 4965 build_sparse_tensor(), 4966 build_sparse_tensor() * 6 4967 ) 4968 4969 @dist_init 4970 def test_nested_remote_sparse(self): 4971 self._nested_remote( 4972 nested_remote_sparse, 4973 build_sparse_tensor() + build_sparse_tensor() 4974 ) 4975 4976 @dist_init 4977 def test_nested_rref_sparse(self): 4978 self._nested_rref( 4979 nested_rref_sparse, 4980 build_sparse_tensor() * 2, 4981 build_sparse_tensor() * 2 4982 ) 4983 4984 @dist_init 4985 def test_nested_rref_stress_sparse(self): 4986 self._nested_rref_stress( 4987 nested_rref_sparse, 4988 build_sparse_tensor() * 2, 4989 build_sparse_tensor() * 2 4990 ) 4991 4992 @dist_init 4993 def test_my_parameter_server_sparse(self): 4994 self._my_parameter_server(True) 4995 4996 # Test init_rpc without world_size argument 4997 @dist_init(setup_rpc=False) 4998 def test_dynamic_rpc_init_rpc(self): 4999 rpc.init_rpc( 5000 name=worker_name(self.rank), 5001 backend=self.rpc_backend, 5002 rank=self.rank, 5003 rpc_backend_options=self.rpc_backend_options, 5004 ) 5005 rpc.shutdown() 5006 5007 # Dynamic RPC new ranks communicate with existing ranks 5008 @dist_init(setup_rpc=False) 5009 def test_dynamic_rpc_new_rank_can_communicated_with_existing_rank(self): 5010 initialize_pg(self.file_init_method, self.rank, self.world_size) 5011 5012 if self.rank == 0: 5013 rpc.init_rpc( 5014 name=worker_name(self.rank), 5015 backend=self.rpc_backend, 5016 rank=self.rank, 5017 rpc_backend_options=self.rpc_backend_options, 5018 ) 5019 5020 # Rank 0 will be initialized with RPC after this barrier 5021 dist.barrier() 5022 5023 if self.rank != 0: 5024 # Newly joined ranks will be able to communicate with rank 0, since that was created first 5025 rpc.init_rpc( 5026 name=worker_name(self.rank), 5027 backend=self.rpc_backend, 5028 rank=self.rank, 5029 rpc_backend_options=self.rpc_backend_options, 5030 ) 5031 result = rpc.rpc_sync(worker_name(0), torch.add, args=(torch.tensor(1), torch.tensor(1))) 5032 self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result) 5033 5034 # Barrier to ensure that all rpc_sync calls are finished 5035 dist.barrier() 5036 rpc.shutdown() 5037 5038 # Dynamic RPC existing ranks can communicate with new ranks 5039 @dist_init(setup_rpc=False) 5040 def 
test_dynamic_rpc_existing_rank_can_communicate_with_new_rank(self): 5041 initialize_pg(self.file_init_method, self.rank, self.world_size) 5042 5043 if self.rank == 0: 5044 rpc.init_rpc( 5045 name=worker_name(self.rank), 5046 backend=self.rpc_backend, 5047 rank=self.rank, 5048 rpc_backend_options=self.rpc_backend_options, 5049 ) 5050 5051 # Rank 0 will be initialized with RPC after this barrier 5052 dist.barrier() 5053 5054 # Rest of ranks join after barrier 5055 if self.rank != 0: 5056 # Newly joined ranks will be able to communicate with rank 0, since that was created first 5057 rpc.init_rpc( 5058 name=worker_name(self.rank), 5059 backend=self.rpc_backend, 5060 rank=self.rank, 5061 rpc_backend_options=self.rpc_backend_options, 5062 ) 5063 5064 dist.barrier() 5065 if self.rank == 0: 5066 for i in range(1, self.world_size): 5067 result = rpc.rpc_sync(worker_name(i), torch.add, args=(torch.tensor(1), torch.tensor(1))) 5068 self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result) 5069 5070 # Barrier to ensure that all rpc_sync calls are finished 5071 dist.barrier() 5072 rpc.shutdown() 5073 5074 # Dynamic RPC existing ranks can communicate with new ranks using CUDA rpc 5075 @skip_if_lt_x_gpu(2) 5076 @dist_init(setup_rpc=False) 5077 def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank_cuda(self): 5078 initialize_pg(self.file_init_method, self.rank, self.world_size) 5079 5080 if self.rank == 0: 5081 options = self.rpc_backend_options 5082 for i in range(1, self.world_size): 5083 dst = worker_name(i) 5084 options.set_device_map(dst, {1: 0}) 5085 options.set_device_map(dst, {0: 1}) 5086 rpc.init_rpc( 5087 name=worker_name(self.rank), 5088 backend=self.rpc_backend, 5089 rank=self.rank, 5090 rpc_backend_options=options, 5091 ) 5092 5093 # Rank 0 will be initialized with RPC after this barrier 5094 dist.barrier() 5095 5096 # Rest of ranks join after barrier 5097 if self.rank != 0: 5098 # Newly joined ranks will be able to communicate with rank 0, since that was created first 5099 rpc.init_rpc( 5100 name=worker_name(self.rank), 5101 backend=self.rpc_backend, 5102 rank=self.rank, 5103 rpc_backend_options=self.rpc_backend_options, 5104 ) 5105 5106 # TODO: Cuda RPC is failing due to: 5107 # terminate called after throwing an instance of 'c10::Error' 5108 # what(): 0 <= device && static_cast<size_t>(device) < device_allocator.size() 5109 # INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1937, 5110 # please report a bug to PyTorch. Allocator not initialized for device 1: did you call init? 
5111 # dist.barrier() 5112 # if self.rank == 0: 5113 # for i in range(1, self.world_size): 5114 # x = torch.ones(2) 5115 # result_on_device_0 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(0), 1)) 5116 # result_on_device_1 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(1), 1)) 5117 # self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_0) 5118 # self.assertEqual(torch.device('cuda:0'), result_on_device_0.device) 5119 # self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_1) 5120 # self.assertEqual(torch.device('cuda:1'), result_on_device_1.device) 5121 5122 # Barrier to ensure that all rpc_sync calls are finished 5123 dist.barrier() 5124 rpc.shutdown() 5125 5126 @dist_init(setup_rpc=False) 5127 def test_dynamic_rpc_init_rpc_without_rank(self): 5128 # default initialization uses file init 5129 with self.assertRaisesRegex(ValueError, "rank parameter missing"): 5130 rpc.init_rpc( 5131 name=worker_name(self.rank), 5132 backend=self.rpc_backend, 5133 rpc_backend_options=self.rpc_backend_options, 5134 ) 5135 5136 # env init 5137 with self.assertRaisesRegex(ValueError, "environment variable RANK expected"): 5138 rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://") 5139 rpc.init_rpc( 5140 name=worker_name(self.rank), 5141 backend=self.rpc_backend, 5142 rpc_backend_options=rpc_backend_options, 5143 ) 5144 5145 # tcp init 5146 with self.assertRaisesRegex(ValueError, "rank parameter missing"): 5147 rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="tcp://127.0.0.1:23456") 5148 rpc.init_rpc( 5149 name=worker_name(self.rank), 5150 backend=self.rpc_backend, 5151 rpc_backend_options=rpc_backend_options, 5152 ) 5153 5154 @dist_init(setup_rpc=False) 5155 def test_dynamic_and_static_init_rpc_together(self): 5156 # Initialize a static rpc group with size = self.world_size - 1 5157 dist.init_process_group( 5158 backend='gloo', 5159 init_method=self.file_init_method, 5160 rank=self.rank, 5161 world_size=self.world_size) 5162 5163 world_size_minus_one = self.world_size - 1 5164 if self.rank < world_size_minus_one: 5165 rpc.init_rpc( 5166 name=worker_name(self.rank), 5167 backend=self.rpc_backend, 5168 rank=self.rank, 5169 world_size=world_size_minus_one, 5170 rpc_backend_options=self.rpc_backend_options, 5171 ) 5172 5173 dist.barrier() 5174 5175 # Attempt to add an additional dynamic group member 5176 if self.rank == world_size_minus_one: 5177 # Expect error message to be thrown 5178 with self.assertRaisesRegex(RuntimeError, "RPC group mixes statically and dynamically\ 5179 initialized members which is not supported."): 5180 rpc.init_rpc( 5181 name=worker_name(self.rank), 5182 backend=self.rpc_backend, 5183 rank=self.rank, 5184 rpc_backend_options=self.rpc_backend_options, 5185 ) 5186 5187class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon): 5188 5189 def _test_device_maps(self, options, errMsg): 5190 with self.assertRaisesRegex(ValueError, errMsg): 5191 rpc.init_rpc( 5192 name=worker_name(self.rank), 5193 backend=self.rpc_backend, 5194 rank=self.rank, 5195 world_size=self.world_size, 5196 rpc_backend_options=options, 5197 ) 5198 5199 self.assertFalse(rpc.api._is_current_rpc_agent_set()) 5200 5201 @skip_if_lt_x_gpu(2) 5202 def test_device_maps_wrong_worker_name(self): 5203 options = self.rpc_backend_options 5204 options.set_device_map("none_exist", {0: 1}) 5205 5206 self._test_device_maps( 5207 options, 5208 errMsg="Node worker0 has invalid target node names in its device maps" 5209 ) 5210 5211 @skip_if_lt_x_gpu(1) 5212 def 
test_device_maps_invalid_max_local_device(self): 5213 options = self.rpc_backend_options 5214 dst = worker_name((self.rank + 1) % self.world_size) 5215 options.set_device_map(dst, {torch.cuda.device_count(): 0}) 5216 5217 self._test_device_maps( 5218 options, 5219 errMsg="Node worker0 has source devices with invalid indices in its device map for worker1" 5220 ) 5221 5222 @skip_if_lt_x_gpu(1) 5223 def test_device_maps_invalid_max_remote_device(self): 5224 options = self.rpc_backend_options 5225 dst = worker_name((self.rank + 1) % self.world_size) 5226 options.set_device_map(dst, {0: torch.cuda.device_count()}) 5227 5228 self._test_device_maps( 5229 options, 5230 errMsg="Node worker0 has target devices with invalid indices in its device map for worker1" 5231 ) 5232 5233 @skip_if_lt_x_gpu(2) 5234 def test_device_maps_many_to_one(self): 5235 options = self.rpc_backend_options 5236 dst = worker_name((self.rank + 1) % self.world_size) 5237 options.set_device_map(dst, {1: 0}) 5238 options.set_device_map(dst, {0: 0}) 5239 5240 self._test_device_maps( 5241 options, 5242 errMsg="Node worker0 has duplicated target devices in its device map for worker1" 5243 ) 5244 5245 @skip_if_lt_x_gpu(2) 5246 def test_device_maps_one_to_many(self): 5247 if self.rank == 0: 5248 options = self.rpc_backend_options 5249 dst = worker_name((self.rank + 1) % self.world_size) 5250 options.set_device_map(dst, {0: 1}) 5251 with self.assertRaisesRegex( 5252 ValueError, "`set_device_map` only supports 1-to-1 mapping" 5253 ): 5254 options.set_device_map(dst, {0: 0}) 5255 5256 @skip_if_lt_x_gpu(1) 5257 def test_device_maps_invalid_min_device(self): 5258 options = self.rpc_backend_options 5259 dst = worker_name((self.rank + 1) % self.world_size) 5260 with self.assertRaisesRegex( 5261 RuntimeError, "Device index must not be negative" 5262 ): 5263 options.set_device_map(dst, {-1: 0}) 5264 5265 with self.assertRaisesRegex( 5266 RuntimeError, "Device index must not be negative" 5267 ): 5268 options.set_device_map(dst, {0: -1}) 5269 5270 @staticmethod 5271 def _gpu_add(x, y): 5272 if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 1]): 5273 return (x + y).to(0) 5274 else: 5275 raise ValueError("Wrong device affinity") 5276 5277 @skip_if_lt_x_gpu(2) 5278 def test_device_maps_gpu(self): 5279 options = self.rpc_backend_options 5280 dst = worker_name((self.rank + 1) % self.world_size) 5281 options.set_device_map(dst, {0: 1, 1: 0}) 5282 5283 rpc.init_rpc( 5284 name=worker_name(self.rank), 5285 backend=self.rpc_backend, 5286 rank=self.rank, 5287 world_size=self.world_size, 5288 rpc_backend_options=options, 5289 ) 5290 5291 ret = rpc.rpc_sync( 5292 dst, 5293 TensorPipeAgentCudaRpcTest._gpu_add, 5294 args=(torch.zeros(2).to(0), torch.ones(2).to(0)) 5295 ) 5296 self.assertEqual(ret.device, torch.device(1)) 5297 self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1)) 5298 rpc.shutdown() 5299 5300 @staticmethod 5301 def _gpu_add_given_devices(x, y, x_to, y_to, z_to): 5302 x_device = "cpu" if x.device.type == "cpu" else x.device.index 5303 y_device = "cpu" if y.device.type == "cpu" else y.device.index 5304 if x_device == x_to and y_device == y_to: 5305 return x.to(z_to) + y.to(z_to) 5306 else: 5307 raise ValueError("Wrong device affinity") 5308 5309 def _test_device_maps_gpu(self, x_from, y_from, z_to, device_map, dst=None, fn=None): 5310 fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn 5311 x_to = device_map[x_from] 5312 y_to = device_map[y_from] 5313 5314 options = self.rpc_backend_options 
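        # The flow being exercised: x and y are created on the caller's x_from /
        # y_from devices, the configured device map moves them to x_to / y_to on
        # the callee, the callee computes the result on z_to, and the reply is
        # routed back through the inverse map, so the caller expects the result
        # on reverse_device_map[z_to].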
5315 dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst 5316 options.set_device_map(dst, device_map) 5317 5318 rpc.init_rpc( 5319 name=worker_name(self.rank), 5320 backend=self.rpc_backend, 5321 rank=self.rank, 5322 world_size=self.world_size, 5323 rpc_backend_options=options, 5324 ) 5325 5326 x = torch.zeros(2).to(x_from) 5327 y = torch.ones(2).to(y_from) 5328 5329 ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to)) 5330 5331 reverse_device_map = {device_map[k] : k for k in device_map} 5332 z_from = reverse_device_map[z_to] 5333 5334 ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index 5335 self.assertEqual(ret_device, z_from) 5336 self.assertEqual(ret, torch.ones(2).to(z_from)) 5337 5338 rpc.shutdown() 5339 5340 def test_device_map_cpu(self): 5341 self._test_device_maps_gpu( 5342 x_from="cpu", 5343 y_from="cpu", 5344 z_to="cpu", 5345 device_map={"cpu" : "cpu"}, 5346 fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, 5347 ) 5348 5349 @skip_if_lt_x_gpu(1) 5350 def test_device_map_cpu_to_gpu_default(self): 5351 self._test_device_maps_gpu( 5352 x_from="cpu", 5353 y_from="cpu", 5354 z_to=0, 5355 device_map={"cpu" : 0}, 5356 fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, 5357 ) 5358 5359 @skip_if_lt_x_gpu(2) 5360 def test_device_map_cpu_to_gpu_non_default(self): 5361 self._test_device_maps_gpu( 5362 x_from="cpu", 5363 y_from="cpu", 5364 z_to=1, 5365 device_map={"cpu" : 1}, 5366 fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, 5367 ) 5368 5369 @skip_if_lt_x_gpu(1) 5370 def test_device_map_gpu_to_cpu_default(self): 5371 self._test_device_maps_gpu( 5372 x_from=0, 5373 y_from=0, 5374 z_to="cpu", 5375 device_map={0 : "cpu"}, 5376 fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, 5377 ) 5378 5379 @skip_if_lt_x_gpu(2) 5380 def test_device_map_gpu_to_cpu_non_default(self): 5381 self._test_device_maps_gpu( 5382 x_from=1, 5383 y_from=1, 5384 z_to="cpu", 5385 device_map={1 : "cpu"}, 5386 fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices, 5387 ) 5388 5389 @skip_if_lt_x_gpu(2) 5390 def test_device_map_gpu_default(self): 5391 self._test_device_maps_gpu( 5392 x_from=0, 5393 y_from=0, 5394 z_to=0, 5395 device_map={0 : 0} 5396 ) 5397 5398 @skip_if_lt_x_gpu(2) 5399 def test_device_map_gpu_non_default(self): 5400 self._test_device_maps_gpu( 5401 x_from=1, 5402 y_from=1, 5403 z_to=1, 5404 device_map={1 : 1} 5405 ) 5406 5407 @skip_if_lt_x_gpu(2) 5408 def test_device_map_gpu_default_to_non_default(self): 5409 self._test_device_maps_gpu( 5410 x_from=0, 5411 y_from=0, 5412 z_to=1, 5413 device_map={0 : 1} 5414 ) 5415 5416 @skip_if_lt_x_gpu(2) 5417 def test_device_map_gpu_non_default_to_default(self): 5418 self._test_device_maps_gpu( 5419 x_from=1, 5420 y_from=1, 5421 z_to=0, 5422 device_map={1 : 0} 5423 ) 5424 5425 @skip_if_lt_x_gpu(2) 5426 def test_device_map_gpu_mixed_1(self): 5427 self._test_device_maps_gpu( 5428 x_from=0, 5429 y_from=1, 5430 z_to=0, 5431 device_map={0 : 0, 1 : 1} 5432 ) 5433 5434 @skip_if_lt_x_gpu(2) 5435 def test_device_map_gpu_mixed_2(self): 5436 self._test_device_maps_gpu( 5437 x_from=0, 5438 y_from=1, 5439 z_to=1, 5440 device_map={0 : 0, 1 : 1} 5441 ) 5442 5443 @skip_if_lt_x_gpu(2) 5444 def test_device_map_gpu_mixed_3(self): 5445 self._test_device_maps_gpu( 5446 x_from=1, 5447 y_from=0, 5448 z_to=0, 5449 device_map={0 : 0, 1 : 1} 5450 ) 5451 5452 @skip_if_lt_x_gpu(2) 5453 def test_device_map_gpu_mixed_4(self): 5454 self._test_device_maps_gpu( 5455 x_from=1, 5456 y_from=0, 5457 z_to=1, 5458 device_map={0 : 0, 1 : 1} 5459 ) 
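    # The mixed-device-map tests above and below all drive _test_device_maps_gpu
    # with a 1-to-1 device map.  For a map such as {0: 1, 1: 0}, a tensor on the
    # caller's cuda:0 arrives on the callee's cuda:1 (and vice versa), and the
    # return value is routed back to the caller through the inverse of the same
    # dictionary (the `reverse_device_map` computed in _test_device_maps_gpu).
    # The helper below is a minimal illustrative sketch of that inversion step;
    # it is not used by any test.
    @staticmethod
    def _example_invert_device_map(device_map):
        # set_device_map enforces 1-to-1 mappings, so inversion is well defined.
        return {callee_dev: caller_dev for caller_dev, callee_dev in device_map.items()}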
    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_5(self):
        self._test_device_maps_gpu(
            x_from=0,
            y_from=1,
            z_to=0,
            device_map={0 : 1, 1 : 0}
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_6(self):
        self._test_device_maps_gpu(
            x_from=0,
            y_from=1,
            z_to=1,
            device_map={0 : 1, 1 : 0}
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_7(self):
        self._test_device_maps_gpu(
            x_from=1,
            y_from=0,
            z_to=0,
            device_map={0 : 1, 1 : 0}
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_8(self):
        self._test_device_maps_gpu(
            x_from=1,
            y_from=0,
            z_to=1,
            device_map={0 : 1, 1 : 0}
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_1(self):
        self._test_device_maps_gpu(
            x_from=0,
            y_from=1,
            z_to=0,
            device_map={0 : 0, 1 : 1},
            dst=worker_name(self.rank)
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_2(self):
        self._test_device_maps_gpu(
            x_from=0,
            y_from=1,
            z_to=1,
            device_map={0 : 0, 1 : 1},
            dst=worker_name(self.rank)
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_3(self):
        self._test_device_maps_gpu(
            x_from=1,
            y_from=0,
            z_to=0,
            device_map={0 : 0, 1 : 1},
            dst=worker_name(self.rank)
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_4(self):
        self._test_device_maps_gpu(
            x_from=1,
            y_from=0,
            z_to=1,
            device_map={0 : 0, 1 : 1},
            dst=worker_name(self.rank)
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_5(self):
        self._test_device_maps_gpu(
            x_from=0,
            y_from=1,
            z_to=0,
            device_map={0 : 1, 1 : 0},
            dst=worker_name(self.rank)
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_6(self):
        self._test_device_maps_gpu(
            x_from=0,
            y_from=1,
            z_to=1,
            device_map={0 : 1, 1 : 0},
            dst=worker_name(self.rank)
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_7(self):
        self._test_device_maps_gpu(
            x_from=1,
            y_from=0,
            z_to=0,
            device_map={0 : 1, 1 : 0},
            dst=worker_name(self.rank)
        )

    @skip_if_lt_x_gpu(2)
    def test_device_map_gpu_mixed_self_8(self):
        self._test_device_maps_gpu(
            x_from=1,
            y_from=0,
            z_to=1,
            device_map={0 : 1, 1 : 0},
            dst=worker_name(self.rank)
        )

    @staticmethod
    def _gpu_add_multi_gpu(x, y):
        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]):
            return x.to(0) + y, x - y.to(1)
        else:
            raise ValueError("Wrong device affinity")

    def _test_device_maps_multi_gpu(self, dst):
        options = self.rpc_backend_options
        options.set_device_map(dst, {0: 1})
        options.set_device_map(dst, {1: 0})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        x = torch.zeros(2).to(0)
        y = torch.ones(2).to(1)
        rets = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu,
            args=(x, y)
        )

        self.assertEqual(rets[0].device, torch.device(1))
        self.assertEqual(rets[1].device, torch.device(0))
        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
        rpc.shutdown()

    @skip_if_lt_x_gpu(2)
    def test_device_maps_multi_gpu(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        self._test_device_maps_multi_gpu(dst)

    @skip_if_lt_x_gpu(2)
    def test_device_maps_multi_gpu_self(self):
        dst = worker_name(self.rank)
        self._test_device_maps_multi_gpu(dst)

    @staticmethod
    def _gpu_add_return_to_gpu(x, y):
        if x.device.type == 'cpu' and y.device.type == 'cpu':
            return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3)
        else:
            raise ValueError("Wrong device affinity")

    @skip_if_lt_x_gpu(2)
    def test_device_maps_in_options(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
                init_method=options.init_method,
                num_worker_threads=options.num_worker_threads,
                device_maps={dst: {0: 1, 1: 0}},
                _transports=tp_transports()
            )
        )

        rets = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu,
            args=(torch.zeros(2).to(0), torch.ones(2).to(1))
        )
        self.assertEqual(rets[0].device, torch.device(1))
        self.assertEqual(rets[1].device, torch.device(0))
        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
        rpc.shutdown()

    def _test_device_maps_return_to_gpu(self, dst):
        options = self.rpc_backend_options

        options.set_device_map(dst, {0: 1})
        options.set_device_map(dst, {1: 2})
        options.set_device_map(dst, {2: 3})
        options.set_device_map(dst, {3: 0})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        rets = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._gpu_add_return_to_gpu,
            args=(torch.zeros(2), torch.ones(2))
        )
        for i in range(len(rets)):
            self.assertEqual(rets[i].device, torch.device((3 + i) % 4))
        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3))
        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
        self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1))
        self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2))
        rpc.shutdown()

    @skip_if_lt_x_gpu(4)
    def test_device_maps_return_to_gpu(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        self._test_device_maps_return_to_gpu(dst)

    @skip_if_lt_x_gpu(4)
    def test_device_maps_return_to_gpu_self(self):
        dst = worker_name(self.rank)
        self._test_device_maps_return_to_gpu(dst)

    @staticmethod
    def _add_to_gpu(x, y):
        return (x + y).to(0)

    def _test_device_maps_missing_config(self, mode):
        dst = worker_name((self.rank + 1) % self.world_size)
        errMsg = (
            "TensorPipe RPC backend only supports CPU tensors by default.*"
            "`set_device_map` on `TensorPipeRpcBackendOptions`"
        )

        with self.assertRaisesRegex(RuntimeError, errMsg):
            if mode == RPCExecMode.SYNC:
                rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1))
            elif mode == RPCExecMode.REMOTE:
                rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here()
            else:
                raise ValueError(f"unexpected mode {mode}")

        # make sure RPC is still functioning
        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
        self.assertEqual(ret, torch.ones(2) + 1)

    def _test_device_maps_missing_config_response(self, mode):
        dst = worker_name((self.rank + 1) % self.world_size)
        errMsg = "Response device mapping is not available"

        with self.assertRaisesRegex(RuntimeError, errMsg):
            if mode == RPCExecMode.SYNC:
                rpc.rpc_sync(
                    dst,
                    TensorPipeAgentCudaRpcTest._add_to_gpu,
                    args=(torch.zeros(2), 1)
                )
            elif mode == RPCExecMode.REMOTE:
                rpc.remote(
                    dst,
                    TensorPipeAgentCudaRpcTest._add_to_gpu,
                    args=(torch.zeros(2), 1)
                ).to_here()
            else:
                raise ValueError(f"unexpected mode {mode}")

        # make sure RPC is still functioning
        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
        self.assertEqual(ret, torch.ones(2) + 1)

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_device_maps_missing_config(self):
        self._test_device_maps_missing_config(RPCExecMode.SYNC)

    @skip_if_lt_x_gpu(1)
    def test_device_maps_missing_config_not_timeout(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options
        )

        timeout = rpc.get_rpc_timeout()

        tik = time.time()
        self._test_device_maps_missing_config(RPCExecMode.SYNC)
        rpc.shutdown()
        tok = time.time()

        self.assertTrue(tok - tik < timeout)

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_device_maps_missing_config_loop(self):
        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
            self._test_device_maps_missing_config(RPCExecMode.SYNC)

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_device_maps_missing_config_response(self):
        self._test_device_maps_missing_config_response(RPCExecMode.SYNC)

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_device_maps_missing_config_response_loop(self):
        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
            self._test_device_maps_missing_config_response(RPCExecMode.SYNC)

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_device_maps_missing_config_remote(self):
        self._test_device_maps_missing_config(RPCExecMode.REMOTE)

    @skip_if_lt_x_gpu(1)
    @dist_init
    def test_device_maps_missing_config_remote_response(self):
        self._test_device_maps_missing_config_response(RPCExecMode.REMOTE)

    @skip_if_lt_x_gpu(2)
    def test_device_maps_remote(self):
        options = self.rpc_backend_options
        dst = worker_name((self.rank + 1) % self.world_size)
        options.set_device_map(dst, {1: 0})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        rref = rpc.remote(
            dst,
            TensorPipeAgentCudaRpcTest._add_to_gpu,
            args=(torch.zeros(2), 1)
        )

        self.assertEqual(rref.to_here().device.index, 1)
        self.assertEqual(rref.to_here(), torch.ones(2).to(1))

        rpc.shutdown()

    @staticmethod
    def _slow_add_on_user_stream(x, y):
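        # Compute the sum on a side stream after a long device-side sleep, so that
        # missing CUDA stream synchronization on the caller side would show up as
        # wrong results; record_stream() tells the caching allocator that the
        # tensors are in use on that stream.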
        s0 = torch.cuda.current_stream(x.device)
        s1 = torch.cuda.Stream(device=x.device)
        s1.wait_stream(s0)
        x.record_stream(s1)
        y.record_stream(s1)
        with torch.cuda.stream(s1):
            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
            z = x + y
        s0.wait_stream(s1)
        z.record_stream(s0)
        return z

    def _test_custom_stream(self, fn, device_map):
        options = self.rpc_backend_options
        dst = worker_name((self.rank + 1) % self.world_size)
        options.set_device_map(dst, device_map)

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        fn(dst)

        rpc.shutdown()

    def _test_stream_sync(self, dst):
        x = torch.ones(2, 2).to(0)
        ret = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
            args=(x, x)
        )
        self.assertEqual(ret, 2 * x)

    @skip_if_lt_x_gpu(2)
    def test_custom_stream(self):
        self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"})

    def _test_stream_multi_async(self, dst):
        futs = []
        for i in range(20):
            x = torch.ones(2, 2).to(0) * i
            futs.append(
                rpc.rpc_async(
                    dst,
                    TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
                    args=(x, x)
                )
            )

        for i in range(20):
            self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i)

    @skip_if_lt_x_gpu(2)
    def test_custom_stream_multi(self):
        self._test_custom_stream(
            self._test_stream_multi_async,
            {"cuda:0": "cuda:1"}
        )

    @staticmethod
    def _nested_slow_add_on_user_stream(dst, x, y, z):
        ret = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
            args=(x, y)
        )

        return TensorPipeAgentCudaRpcTest._slow_add_on_user_stream(ret, z)

    def _test_stream_nested_sync(self, dst):
        x = torch.ones(2, 2).to(0)
        y = torch.ones(2, 2).to(0) * 2
        z = torch.ones(2, 2).to(0) * 3
        nested_dst = worker_name((self.rank + 2) % self.world_size)
        ret = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
            args=(nested_dst, x, y, z)
        )
        self.assertEqual(ret, 6 * x)

    @skip_if_lt_x_gpu(2)
    def test_custom_stream_nested(self):
        self._test_custom_stream(
            self._test_stream_nested_sync,
            {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
        )

    def _test_stream_nested_multi_async(self, dst):
        if self.rank == 0:
            futs = []
            n = 5
            xs, ys, zs = [], [], []
            for i in range(n):
                x = torch.ones(2, 2).to(0) * (i - 1)
                y = torch.ones(2, 2).to(0) * i
                z = torch.ones(2, 2).to(0) * (i + 1)
                xs.append(x)
                ys.append(y)
                zs.append(z)
                nested_dst = worker_name((self.rank + 2) % self.world_size)
                futs.append(
                    rpc.rpc_async(
                        dst,
                        TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
                        args=(nested_dst, x, y, z)
                    )
                )

            for i in range(n):
                self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i])

    @skip_if_lt_x_gpu(2)
    def test_custom_stream_nested_multi(self):
        self._test_custom_stream(
            self._test_stream_nested_multi_async,
            {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
        )

    @staticmethod
    def _gpu_add_wrong_gpus(x, y):
        if x.is_cuda and y.is_cuda:
            return x.cpu() + y.cuda()
        else:
            raise ValueError("Wrong device affinity")
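
    # _gpu_add_wrong_gpus deliberately mixes a CPU and a CUDA tensor; the test
    # below expects the resulting "Expected all tensors to be on the same device"
    # error to be propagated back to the caller.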
    @skip_if_lt_x_gpu(1)
    def test_device_mismatch(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {0: 0})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        x = torch.zeros(2).to(0)
        y = torch.ones(2).to(0)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected all tensors to be on the same device, but found at least two devices"
        ):
            rets = rpc.rpc_sync(
                dst,
                TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus,
                args=(x, y)
            )

        rpc.shutdown()

    def _test_rref_synchronization(self, local_device, remote_device):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {local_device : remote_device})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        if self.rank == 1:
            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
            # If to_here() is properly synchronized with forward(x) the results must be identical
            # This test needs multiple iterations and significant batch size to simulate real
            # training of a CNN of MNIST-like data.
            # see https://github.com/pytorch/pytorch/issues/54771
            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
            for _ in range(10):
                x = torch.randn(200, 1, 28, 28).to(local_device)
                actual = rref.remote().forward(x).to_here()
                expected = rref.rpc_sync().forward(x)
                self.assertEqual(actual, expected)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_rref_to_here_synchronization1(self):
        self._test_rref_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_to_here_synchronization2(self):
        self._test_rref_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_to_here_synchronization3(self):
        self._test_rref_synchronization("cuda:1", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_rref_to_here_synchronization4(self):
        self._test_rref_synchronization("cuda:0", "cuda:1")

    def _test_rref_as_arg_synchronization(
        self,
        local_device,
        remote_device,
        devicesOptions=None
    ):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {local_device: remote_device})

        input_src = worker_name((self.rank - 1 + self.world_size) % self.world_size)
        options.set_device_map(input_src, {remote_device: local_device})

        if devicesOptions is not None:
            options.set_devices(devicesOptions[self.rank])

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        if self.rank == 1:
            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
            # If to_here() is properly synchronized with forward(x) the results must be identical
            # This test needs multiple iterations and significant batch size to simulate real
            # training of a CNN of MNIST-like data.
            # see https://github.com/pytorch/pytorch/issues/54771
            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
            for _ in range(10):
                rref_x = RRef(torch.randn(200, 1, 28, 28).to(local_device))
                actual = rref.remote().forward(rref_x, True).to_here()
                expected = rref.rpc_sync().forward(rref_x, True)
                self.assertEqual(actual, expected)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_rref_as_arg_synchronization1(self):
        self._test_rref_as_arg_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_as_arg_synchronization2(self):
        self._test_rref_as_arg_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_as_arg_synchronization3(self):
        self._test_rref_as_arg_synchronization("cuda:1", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_rref_as_arg_synchronization4(self):
        self._test_rref_as_arg_synchronization("cuda:0", "cuda:1")

    @skip_if_lt_x_gpu(1)
    def test_rref_as_arg_synchronization5(self):
        self._test_rref_as_arg_synchronization(
            "cuda:0",
            "cuda:0",
            [["cuda:0"] for _ in range(4)],  # devicesOptions
        )

    @staticmethod
    def _rref_relay(rref):
        return rref.to_here()

    def _test_rref_forward_synchronization(self, local_device, remote_device):
        options = self.rpc_backend_options

        input_src = worker_name(0)
        model_dst = worker_name(1)
        out_relay = worker_name(2)

        if self.rank == 0:
            # for 1) model construction 2) forward execution
            options.set_device_map(model_dst, {local_device: remote_device})

            # Forward output will be first copied to the relay node before
            # returning to the worker. This is intentional, to test RRef
            # forward CUDA stream synchronizations.
            options.set_device_map(out_relay, {local_device: local_device})
        elif self.rank == 1:
            # worker1 hosts the model and runs forward. The forward function
            # calls RRef.to_here(), hence it needs to configure the device map.
            options.set_device_map(input_src, {remote_device: local_device})
        elif self.rank == 2:
            # worker2 will get the out RRef and call to_here(), hence it needs
            # to configure the device map.
            options.set_device_map(model_dst, {local_device: remote_device})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        if self.rank == 0:
            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
            # If to_here() is properly synchronized with forward(x) the results must be identical
            # This test needs multiple iterations and significant batch size to simulate real
            # training of a CNN of MNIST-like data.
            # see https://github.com/pytorch/pytorch/issues/54771
            rref = rpc.remote(model_dst, MyConvNetForMNIST, args=(remote_device,))
            for _ in range(10):
                rref_input = RRef(torch.randn(200, 1, 28, 28).to(local_device))
                rref_out = rref.remote().forward(rref_input, True)
                out = rpc.remote(
                    out_relay,
                    TensorPipeAgentCudaRpcTest._rref_relay,
                    args=(rref_out,)
                ).to_here()
                expected = rref.rpc_sync().forward(rref_input, True)
                self.assertEqual(out, expected)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_rref_forward_synchronization1(self):
        self._test_rref_forward_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_forward_synchronization2(self):
        self._test_rref_forward_synchronization("cuda:0", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_rref_forward_synchronization3(self):
        self._test_rref_forward_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_forward_synchronization4(self):
        self._test_rref_forward_synchronization("cuda:1", "cuda:1")

    def _test_owner_rref_forward_synchronization(self, local_device, remote_device):
        if self.rank == 0:
            options = self.rpc_backend_options
            options.set_device_map("w0", {local_device: remote_device})
            rpc.init_rpc(
                "w0",
                rank=0,
                world_size=1,
                rpc_backend_options=options
            )

            model = rpc.remote(
                "w0", torch.nn.Linear, (2048, 20000)
            ).remote().to(remote_device)
            for _ in range(30):
                data = torch.rand(2048, 2048).to(local_device)
                output = model.rpc_sync().forward(data)
                # to_here() internally calls localValue as the caller is
                # the owner of the RRef.
                v0 = rpc.RRef(output).remote().sum().to_here().item()
                v1 = output.sum().item()
                self.assertEqual(v0, v1)

            rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_owner_rref_forward_synchronization1(self):
        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization2(self):
        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization3(self):
        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization4(self):
        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1")

    @staticmethod
    def _return_tensor_view(i):
        x = torch.ones(1000, 200).cuda(0) * i
        torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
        # serialization of the return value will create a new tensor from the
        # view, which is done outside of the user function.
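        # x.split(100)[0] below is a (100, 200) view covering the first chunk of x.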
        return x.split(100)[0]

    @skip_if_lt_x_gpu(1)
    def test_tensor_view_as_return_value(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {0 : 0})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        futs = []
        for i in range(5):
            futs.append(rpc.rpc_async(
                dst,
                TensorPipeAgentCudaRpcTest._return_tensor_view,
                args=(i,)
            ))

        for i in range(5):
            self.assertEqual(torch.ones(100, 200) * i, futs[i].wait())

        rpc.shutdown()

    @skip_if_lt_x_gpu(2)
    def test_devices_option_mismatch(self):
        with self.assertRaisesRegex(
            ValueError,
            "Node worker0 has unexpected source devices in its device map for worker1"
        ):
            dst = worker_name((self.rank + 1) % self.world_size)
            options = self.rpc_backend_options
            options.set_device_map(dst, {0 : 0})
            options.set_devices([1])

            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=options,
            )

            rpc.shutdown()

    @skip_if_lt_x_gpu(2)
    def test_devices_option_mismatch_reverse(self):
        with self.assertRaisesRegex(
            ValueError,
            "Node worker0 has unexpected target devices in its device map for worker1"
        ):
            dst = worker_name((self.rank + 1) % self.world_size)

            options = rpc.TensorPipeRpcBackendOptions(
                init_method=self.rpc_backend_options.init_method,
                num_worker_threads=self.rpc_backend_options.num_worker_threads,
                device_maps={dst: {0 : 1}},
                devices=[0]
            )

            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=options,
            )

            rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_int(self):
        fut = Future(devices=[0])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_str(self):
        fut = Future(devices=["cuda:0"])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_device(self):
        fut = Future(devices=[torch.device("cuda", 0)])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_not_cuda(self):
        with self.assertRaisesRegex(
            ValueError, "Expected devices to have indices, got cpu"
        ):
            fut = Future(devices=["cpu"])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=False
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_list_with_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=False
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=False
        )

    @skip_if_lt_x_gpu(2)
    def test_cuda_future_callback_changes_devices(self):
        # We check proper CUDA stream synchronization by filling the tensor with
        # the expected value in one stream, and reading it from another stream.
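        # The callback copies the parent's cuda:0 result into a tensor on cuda:1,
        # which is why the parent future is created with both devices.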
        tensor0 = torch.zeros((100,), device="cuda:0")
        tensor1 = torch.zeros((100,), device="cuda:1")
        parent_future = Future(devices=["cuda:0", "cuda:1"])

        def cb(fut):
            t0 = fut.value()
            tensor1.copy_(t0, non_blocking=True)
            return tensor1

        child_future = parent_future.then(cb)
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor0.fill_(1)
                parent_future.set_result(tensor0)
        with torch.cuda.device("cuda:1"):
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(another_stream):
                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())

    @skip_if_lt_x_gpu(2)
    def test_cuda_future_value_on_bad_device(self):
        tensor0 = torch.zeros((100,), device="cuda:0")
        tensor1 = torch.zeros((100,), device="cuda:1")
        parent_future = Future(devices=["cuda:1"])

        # As a plus, we test that futures still invoke callbacks even in case of
        # error, and that the child futures are successful if those callbacks
        # don't access the parent future.
        def cb(fut):
            with torch.cuda.device("cuda:1"):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor1.fill_(1)
                return tensor1

        child_future = parent_future.then(cb)
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor0.fill_(1)
                parent_future.set_result(tensor0)
        with self.assertRaisesRegex(
            ValueError,
            r"The result contained tensors residing on device\(s\) cuda:0 "
            r"which are not among the expected device\(s\) cuda:1",
        ):
            parent_future.wait()
        with torch.cuda.device("cuda:1"):
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(another_stream):
                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())

    @skip_if_lt_x_gpu(1)
    def test_async_execution_with_cuda_future(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {"cuda:0": "cuda:0"})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        t = torch.zeros((100,), device="cuda:0")
        fut = rpc.rpc_async(dst, async_cuda_sleep_and_set_to_one, args=(t,))
        another_stream = torch.cuda.Stream("cuda:0")
        with torch.cuda.stream(another_stream):
            self.assertTrue(torch.eq(fut.wait(), 1).all().item())

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_async_execution_nested_with_cuda_future(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        nested_dst = worker_name((self.rank + 2) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {"cuda:0": "cuda:0"})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        a = torch.ones((100,), device="cuda:0")
        b = torch.ones((100,), device="cuda:0")
        c = torch.ones((100,), device="cuda:0")
        fut = rpc.rpc_async(dst, async_cuda_nested_add, args=(nested_dst, a, b, c))
        another_stream = torch.cuda.Stream("cuda:0")
        with torch.cuda.stream(another_stream):
            self.assertTrue(torch.eq(fut.wait(), 3).all().item())

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_modify_tensor_inplace(self):
        tensor = torch.zeros((100,), device="cuda:0")
        future = Future(devices=["cuda:0"])
        future.set_result(tensor)
        # It's weird to modify the value of a future once it's complete, but
        # technically possible. Currently this is considered undefined behavior
        # (in practice the future will ignore the modification and still
        # synchronize with the original value). We could one day add logic to
        # detect and warn or throw in such cases, but for now we just check that
        # this doesn't crash.
        tensor.fill_(1)
        future.wait()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_replace_tensor(self):
        tensor_list = [torch.zeros((100,), device="cuda:0")]
        future = Future(devices=["cuda:0"])
        future.set_result(tensor_list)
        # It's weird to modify the value of a future once it's complete, but
        # technically possible. Currently this is considered undefined behavior
        # (in practice the future will ignore the modification and still
        # synchronize with the original value). We could one day add logic to
        # detect and warn or throw in such cases, but for now we just check that
        # this doesn't crash.
        # We set things up so that the original tensor contained in the list
        # gets deleted once we replace it with the other one. This will
        # invalidate any cached information held by the future.
        tensor_list[0] = torch.ones((100,), device="cuda:0")
        future.wait()

    @skip_if_lt_x_gpu(1)
    def test_rref_with_unpickleable_attributes(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {"cuda:0": "cuda:0"})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        rref = rpc.remote(dst, TensorWrapper, args=(torch.zeros(42, device="cuda:0"),))
        rref.rpc_sync().increase(1)
        ret = rref.rpc_sync().sum()
        self.assertEqual(ret, 42)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_cuda_sparse_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=True
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=True
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_custom_class_with_cuda_sparse_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=True
        )