# mypy: allow-untyped-defs

import io
import time
from typing import Any, Dict, List, Tuple

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.autograd.profiler import record_function
from torch.autograd.profiler_legacy import profile as _profile
from torch.distributed.rpc import RRef
from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key
from torch.futures import Future
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.dist_utils import (
    dist_init,
    get_function_event,
    initialize_pg,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


def rref_isinstance(rref, cls_to_check):
    return isinstance(rref.local_value(), cls_to_check)


def sleep(t):
    time.sleep(t)


def rpc_return_rref(dst):
    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))


@torch.jit.script
def rref_local_value(rref: RRef[Tensor]) -> Tensor:
    return rref.local_value()


@torch.jit.script
def list_create() -> List[int]:
    global_list = [1, 2, 3]
    return global_list


@torch.jit.script
def rref_list_mutate(rref: RRef[List[int]]) -> None:
    rref.local_value().append(4)
    rref.to_here().append(5)
    # to_here() also accepts an optional timeout in seconds.
    rref.to_here(5.0).append(6)


def return_value(value: int) -> int:
    return value


class RRefAPITest:
    @dist_init
    def test_rref_is_owner(self):
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref_var = rpc_return_rref(dst_worker_name)

        @torch.jit.script
        def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool:
            return rref_var.is_owner()

        res = rref_tensor_is_owner(rref_var)
        self.assertEqual(res, False)

    @dist_init
    def test_rref_local_value(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref = rpc_return_rref(dst_worker_name)

        with self.assertRaisesRegex(
            RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef"
        ):
            rref_local_value(rref)

        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
        self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))

    @dist_init
    def test_local_rref_local_value(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name(self.rank)
        rref = rpc.remote(dst_worker_name, return_value, (5,), {})

        ret = rref_local_value(rref)
        self.assertEqual(ret, 5)

    def _create_rref(self):
        owner_rank = (self.rank + 2) % self.world_size
        return rpc.remote(
            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
        )

    @dist_init
    def test_user_rrefs_confirmed(self):
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret = rpc.rpc_sync(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret, True)

    @dist_init
    def test_user_rrefs_confirmed_remote(self):
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret_rref = rpc.remote(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret_rref.to_here(), True)

    @dist_init
    def test_rref_list_mutate(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        list_rref = rpc.remote(dst, list_create)

        rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,))
        self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6])
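

# A minimal sketch (not invoked by the tests above) of the owner/user split
# that RRefAPITest exercises: the worker named in rpc.remote owns the value,
# while the caller only holds a user RRef, so is_owner() is False on the
# caller, local_value() is only valid on the owner, and a user must fetch the
# value with to_here().
def _example_rref_round_trip(dst_worker_name: str) -> Tensor:
    user_rref = rpc.remote(dst_worker_name, torch.add, args=(torch.ones(2, 2), 1))
    # Holds when dst_worker_name names a different worker.
    assert not user_rref.is_owner()
    return user_rref.to_here()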


@torch.jit.script
def no_arg():
    return 0


@torch.jit.script
def one_arg(value):
    return value + 1


@torch.jit.script
def script_add_ones(x):
    return torch.add(x, torch.ones(1))


@torch.jit.script
def script_add_ones_with_record_function(x, block: str):
    with record_function(block):
        return torch.add(x, torch.ones(1))


@torch.jit.script
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
    t: Tensor = torch.ones(1)
    with record_function(block) as rf:
        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        # Extra operator call to avoid de-duplication of the next async call,
        # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
        zero = torch.zeros_like(t)
        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        res = fut1.wait() + fut2.wait() + zero
    return res


@torch.jit.script
def script_fork_wait_udf(tensor):
    fut = torch.jit._fork(script_add_ones, tensor)
    x = torch.jit._wait(fut)
    return x


@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    return rref_var.to_here()


@torch.jit.script
def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    return rref_var


@torch.jit.script
def script_raise_func(value):
    if value.numel() == 2:
        raise ValueError("Expected error")
    return value + 1


@torch.jit.script
def script_fork_wait_throw(invalue):
    fut = torch.jit._fork(script_raise_func, invalue)
    value = torch.jit._wait(fut)
    return value


@torch.jit.script
def call_rpc_with_profiling(record: torch.classes.profiler._RecordFunction, dst_worker_name: str) -> Tensor:
    # Call rpc_async from within a ScriptFunction and ensure that we can
    # attach profiling callbacks through the RecordFunction handle passed in
    # as `record`.
    fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
    ret = fut.wait()
    return ret


@torch.jit.script
def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str) -> Tensor:
    fut = rpc.rpc_async(dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block))
    return fut.wait()
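

# Note on the two profiling helpers above: record_function opens a profiling
# range, and torch.ops.profiler._call_end_callbacks_on_jit_fut defers running
# that range's end callbacks until the given future completes, so the recorded
# time spans the asynchronous RPC (or fork) work rather than just the call
# that launched it.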


@torch.jit.script
def call_fork_with_profiling(record: torch.classes.profiler._RecordFunction) -> Tensor:
    # Call fork from within a ScriptFunction and ensure that we can attach
    # profiling callbacks to the resulting future through the RecordFunction
    # handle passed in as `record`.
    fut = torch.jit._fork(one_arg, torch.tensor(1))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
    ret = fut.wait()
    return ret


class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
    def __init__(self, dst_worker):
        super().__init__()
        self.rrefs = []
        for _ in range(4):
            self.rrefs.append(rpc_return_rref(dst_worker))

    @torch.jit.script_method
    def forward(self) -> Tensor:
        res_tensor = torch.ones(2, 2)
        for rref in self.rrefs:
            res_tensor += rref.to_here()

        return res_tensor


@torch.jit.ignore
def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    return rref_var


@torch.jit.script
def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor:
    return rref_python_annotation(rref_var).to_here()


class RRefTypingTest:
    @dist_init
    def test_rref_as_arg_and_return(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        local_ret = one_arg(torch.ones(2, 2))

        # Create an RRef on the current rank.
        rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),))

        # Pass the RRef to another user in an rpc call.
        ret = rpc.rpc_sync(worker_name(dst_rank), rref_to_here, args=(rref,))
        self.assertEqual(ret, local_ret)

        # Return an RRef from an rpc call.
        rref1 = rpc.rpc_sync(worker_name(dst_rank), return_rref, args=(rref,))
        self.assertEqual(rref1.to_here(), local_ret)

        # Pass the RRef to another user in a remote call.
        rref2 = rpc.remote(worker_name(dst_rank), rref_to_here, args=(rref,))
        self.assertEqual(rref2.to_here(), local_ret)

        # Return an RRef from a remote call.
        rref3 = rpc.remote(worker_name(dst_rank), return_rref, args=(rref,))
        self.assertEqual(rref3.to_here().to_here(), local_ret)

    @dist_init
    def test_my_script_module_with_rrefs(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        module_with_rrefs = MyScriptModuleWithRRefs(worker_name(dst_rank))
        res = module_with_rrefs()
        self.assertEqual(res, torch.ones(2, 2) * 9)

    @dist_init
    def test_rref_python_annotation(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_var = rpc_return_rref(worker_name(dst_rank))

        res = rref_script_annotation(rref_var)
        self.assertEqual(res, torch.ones(2, 2) + 1)


class FutureTypingTest:
    @dist_init
    def test_future_passed_between_python_and_jit(self):
        dst_rank = (self.rank + 1) % self.world_size
        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs)
        expected_res = torch.tensor([10, 10])

        @torch.jit.script
        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
            return fut.wait()

        self.assertEqual(future_wait_in_script(ret_fut), expected_res)

        @torch.jit.script
        def future_return_to_python(
            dst_rank: int, inputs: Tuple[Tensor, Tensor]
        ) -> Future[Tensor]:
            return rpc.rpc_async(
                f"worker{dst_rank}", two_args_two_kwargs, inputs
            )

        fut_res = future_return_to_python(dst_rank, inputs)
        self.assertEqual(fut_res.wait(), expected_res)
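
    # A Python function decorated with @torch.jit.ignore can also hand a
    # Future[Tensor] into TorchScript, provided its Python return annotation
    # supplies the type; the next test exercises this.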
    @dist_init
    def test_future_python_annotation(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        input_0 = torch.ones(2, 2)
        input_1 = 1
        expected_res = torch.add(input_0, input_1)

        @torch.jit.ignore
        def python_return_future() -> Future[Tensor]:
            fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})
            return fut

        @torch.jit.script
        def script_use_future() -> Tensor:
            fut = python_return_future()
            return fut.wait()

        res = script_use_future()
        self.assertEqual(res, expected_res)


@torch.jit.script
class MyScriptClass:
    def __init__(self, a: int):
        self.a = a

    def get_value(self) -> int:
        return self.a


@torch.jit.interface
class MyModuleInterface(torch.nn.Module):
    def forward(self) -> Tensor:
        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
        pass


class MyScriptModule(torch.jit.ScriptModule):
    def __init__(self, rank):
        super().__init__()
        self.a = torch.ones(rank)

    @torch.jit.script_method
    def forward(self) -> Tensor:
        return self.a

    @torch.jit.script_method
    def custom_func(self) -> Tensor:
        return self.a


def owner_create_rref_my_script_class(a):
    return rpc.RRef(MyScriptClass(a))


def owner_create_rref_my_script_module(a):
    return rpc.RRef(MyScriptModule(a), type_hint=MyModuleInterface)


@torch.jit.script
def script_rref_get_value_my_script_class(rref: RRef[MyScriptClass]) -> int:
    return rref.to_here().get_value()


@torch.jit.script
def script_rref_run_forward_my_script_module(rref: RRef[MyModuleInterface]) -> Tensor:
    return rref.to_here().forward()


class LocalRRefTest:
    @dist_init
    def test_create_local_script_class_rref_in_py(self):
        if self.rank != 0:
            return

        # Create a local RRef<MyScriptClass>.
        rref_script_class = rpc.RRef(MyScriptClass(self.rank))
        ret = rref_script_class.to_here().get_value()
        self.assertEqual(ret, self.rank)

    @dist_init
    def test_create_local_script_module_rref_in_py(self):
        if self.rank != 0:
            return

        # Create a local RRef<MyModuleInterface>.
        rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
        ret = rref_script_module.to_here().forward()
        self.assertEqual(ret, torch.ones(self.rank))

        # Creating a local RRef<MyModuleInterface> without a type hint
        # should fail.
        with self.assertRaisesRegex(
            RuntimeError,
            (
                "The RRef being created contains a ScriptModule, "
                "must provide its ModuleInterface type hint."
            ),
        ):
            rref_script_module = rpc.RRef(MyScriptModule(self.rank))

    @dist_init
    def test_return_local_script_class_rref_in_py_and_use_in_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Create a local RRef<MyScriptClass> remotely in Python.
        rref = rpc.rpc_sync(
            dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,)
        )

        def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int:
            args = (rref,)
            kwargs: Dict[str, Any] = {}
            fut = rpc.rpc_async(
                rref.owner(), script_rref_get_value_my_script_class, args, kwargs
            )
            ret = fut.wait()
            return ret

        # Use RRef<MyScriptClass> in local Python RPC and remote Script run.
        ret = use_rref_on_owner(rref)
        self.assertEqual(ret, self.rank)

        # Use RRef<MyScriptClass> in local Script RPC and remote Script run.
        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
        ret = use_rref_on_owner_script(rref)
        self.assertEqual(ret, self.rank)

    @dist_init
    def test_return_local_script_module_rref_in_py_and_use_in_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Create a local RRef<MyModuleInterface> remotely in Python.
        rref = rpc.rpc_sync(
            dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,)
        )

        def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor:
            args = (rref,)
            kwargs: Dict[str, Any] = {}
            fut = rpc.rpc_async(
                rref.owner_name(),
                script_rref_run_forward_my_script_module,
                args,
                kwargs,
            )
            ret = fut.wait()
            return ret

        # Use RRef<MyModuleInterface> in local Python RPC and remote Script run.
        ret = use_rref_on_owner(rref)
        self.assertEqual(ret, torch.ones(self.rank))

        # Use RRef<MyModuleInterface> in local Script RPC and remote Script run.
        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
        ret = use_rref_on_owner_script(rref)
        self.assertEqual(ret, torch.ones(self.rank))


def python_function():
    return 0


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    return first_arg + second_arg + first_kwarg + second_kwarg


@torch.jit.script
def assorted_types_args_kwargs(
    tensor_arg: Tensor,  # noqa: E999
    str_arg: str,
    int_arg: int,
    tensor_kwarg: Tensor = torch.tensor([2, 2]),
    str_kwarg: str = "str_kwarg",
    int_kwarg: int = 2,
):
    return tensor_arg + tensor_kwarg, str_arg + str_kwarg, int_arg + int_kwarg


@torch.jit.script
def raise_script():
    raise RuntimeError("Expected error")


@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
    return ret


@torch.jit.script
def script_rpc_sync_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return res


@torch.jit.script
def script_rpc_remote_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return rref_res.to_here()


class JitRpcOpTest:
    # Call functions remotely from Script.
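    # Note on the matching semantics exercised below: positional args are
    # matched against the TorchScript schema while the caller is being
    # scripted, whereas kwargs are matched at execution time, and kwargs that
    # are not passed are populated from the schema defaults.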
    @dist_init
    def test_all_kwargs_are_populated_by_defaults(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {}

        for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
            ret = script_op(dst_worker_name, args, kwargs)
            self.assertEqual(ret, torch.tensor([10, 10]))

    @dist_init
    def test_some_kwargs_are_populated_by_defaults(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {"first_kwarg": torch.tensor([2, 2])}

        for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
            ret = script_op(dst_worker_name, args, kwargs)
            self.assertEqual(ret, torch.tensor([9, 9]))

    @dist_init
    def test_no_kwargs_are_populated_by_defaults(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
            ret = script_op(dst_worker_name, args, kwargs)
            self.assertEqual(ret, torch.tensor([8, 8]))

    @dist_init
    def test_args_and_kwargs_contain_different_types(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def script_rpc_async_call_with_assorted_types(
            dst_worker_name: str,
        ):
            args = (torch.tensor([1, 1]), "str_arg", 1)
            # Must annotate the value type as `Any`, because JIT type inference
            # does not support multiple types when defining a Dict.
            # The error JIT gives is:
            # "Dict values must contain only a single type, "
            # "expected: Tensor but found str instead."
            kwargs: Dict[str, Any] = {
                "tensor_kwarg": torch.tensor([3, 3]),
                "str_kwarg": "_str_kwarg",
                "int_kwarg": 3,
            }
            fut = rpc.rpc_async(
                dst_worker_name, assorted_types_args_kwargs, args, kwargs
            )
            ret = fut.wait()
            return ret

        ret = script_rpc_async_call_with_assorted_types(dst_worker_name)
        self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4))

    @dist_init
    def test_kwargs_not_passed(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def script_rpc_async_call_without_kwargs_passed(
            dst_worker_name: str,
        ):
            args = ()
            fut = rpc.rpc_async(dst_worker_name, no_arg, args)
            ret = fut.wait()
            return ret

        ret = script_rpc_async_call_without_kwargs_passed(dst_worker_name)
        self.assertEqual(ret, 0)

    @dist_init
    def test_args_kwargs_are_neither_passed(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def script_rpc_async_call_without_args_kwargs_passed(
            dst_worker_name: str,
        ):
            fut = rpc.rpc_async(dst_worker_name, no_arg)
            ret = fut.wait()
            return ret

        ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
        self.assertEqual(ret, 0)

    @dist_init
    def test_less_than_needed_args_are_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Note that args matching happens during scripting.
        with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"):

            @torch.jit.script
            def script_rpc_async_call_with_less_args(
                dst_worker_name: str,  # noqa: E999
            ):
                args = (torch.tensor([1, 1]),)
                kwargs = {}
                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
                ret = fut.wait()
                return ret

    @dist_init
    def test_more_than_needed_args_are_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Note that args matching happens during scripting.
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected at most 4 arguments but found 5 positional arguments",
        ):

            @torch.jit.script
            def script_rpc_async_call_with_more_args(
                dst_worker_name: str,
            ):
                args = (
                    torch.tensor([1, 1]),
                    torch.tensor([2, 2]),
                    torch.tensor([3, 3]),
                    torch.tensor([4, 4]),
                    torch.tensor([5, 5]),
                )
                kwargs = {}
                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
                ret = fut.wait()
                return ret

    @dist_init
    def test_unexpected_kwarg_is_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Note that kwargs matching happens during execution.
        @torch.jit.script
        def script_rpc_async_call_with_unexpected_kwarg(
            dst_worker_name: str,  # noqa: E999
        ):
            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
            kwargs = {"third_kwarg": torch.tensor([1, 1])}
            fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(
            RuntimeError, "Unknown keyword argument 'third_kwarg'"
        ):
            ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
            self.assertEqual(ret, 0)

    @dist_init
    def test_call_python_function_remotely_from_script_not_supported(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(
            RuntimeError, "attempted to get undefined function"
        ):
            ret = rpc_async_call_remote_py_function_in_torchscript(dst_worker_name)
            self.assertEqual(ret, 0)

    @dist_init
    def test_call_script_function_that_raises_remotely_from_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Note that TorchScript emits every Python `raise` statement as the
        # generic exception message string "Exception", regardless of the
        # exception type and message used in the statement.
        @torch.jit.script
        def rpc_async_call_remote_raising_torchscript_in_torchscript(
            dst_worker_name: str,
        ):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            ret = rpc_async_call_remote_raising_torchscript_in_torchscript(
                dst_worker_name
            )
            self.assertEqual(ret, 0)

    @dist_init
    def test_call_script_function_that_not_exists_remotely_from_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def nonexisting_script():
            return 0

        @torch.jit.script
        def rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
            dst_worker_name: str,
        ):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, nonexisting_script, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(
            RuntimeError, "attempted to get undefined function nonexisting_script"
        ):
            ret = rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
                dst_worker_name
            )
            self.assertEqual(ret, 0)


@torch.jit.ignore
def my_script_module_init(rank: int) -> MyModuleInterface:
    return MyScriptModule(rank)


@torch.jit.script
def construct_my_script_module(rank: int) -> MyModuleInterface:
    return my_script_module_init(rank)


@torch.jit.script
def run_ref_script_module(
    ref_script_module: RRef[MyModuleInterface], t: Tensor
) -> Tensor:
    module = ref_script_module.to_here()
    return module.forward() + t


@torch.jit.script
def script_check_rref_confirmed(rref: RRef[Tensor]) -> bool:
    return rref.confirmed_by_owner()


@torch.jit.script
def save_rref(rref_var: RRef[Tensor], fname: str) -> None:
    torch.save(rref_var, fname)
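

# Note the decorator order on async_add below: @torch.jit.script is applied
# first (it is the inner decorator), so @rpc.functions.async_execution wraps
# an already-scripted function. The reversed order fails to script, as
# test_async_function_wrong_decorator_order exercises.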


@torch.jit.script
def script_add(x: Tensor, y: Tensor) -> Tensor:
    return x + y


@rpc.functions.async_execution
@torch.jit.script
def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
    return rpc.rpc_async(to, script_add, (x, y))


@rpc.functions.async_execution
@torch.jit.script
def async_wrong_type() -> Tensor:
    return torch.zeros(2)


def load_script_module_with_pickled_rref(pickled_script_module):
    f = io.BytesIO(pickled_script_module)
    m = torch.jit.load(f)
    return m()


class JitRpcTest(
    RRefAPITest,
    RRefTypingTest,
    LocalRRefTest,
    JitRpcOpTest,
    FutureTypingTest,
    RpcAgentTestFixture,
):
    @dist_init
    def test_torchscript_function(self):
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        local_ret = one_arg(torch.ones(2, 2))
        ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
        self.assertEqual(ret, local_ret)
        rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
        self.assertEqual(rref.to_here(), local_ret)
        # Create an RRef to the local worker itself.
        local_rref = rpc.remote(
            worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)
        )
        self.assertEqual(local_rref.to_here(), local_ret)

    @dist_init
    def test_torchscript_function_exception(self):
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
            ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20))

        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
            rref = rpc.remote(dst_worker_name, one_arg, args=(10, 20))

    @dist_init
    def test_torchscript_functions_not_supported(self):
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        my_local_script_module = MyScriptModule(self.rank)

        # It is not thread safe to instantiate MyScriptModule in multiple
        # threads, so wait for the local MyScriptModule instantiation to
        # finish; otherwise it could be instantiated in parallel with the
        # server thread below.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()

        # rpc_sync still accepts a script class and runs it in the same code
        # path as a Python call.
        ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,))

        # rpc_sync does not accept a script module method. Python 3.5 and
        # Python 3.6 throw different error messages; the only common word
        # that can be grepped for is "pickle".
        with self.assertRaisesRegex(TypeError, "pickle"):
            ret = rpc.rpc_async(
                dst_worker_name, my_local_script_module.forward, args=()
            )

    @dist_init
    def test_remote_script_module(self):
        # TODO: needs more investigation. There is an RRef leak when shutting
        # down; we suspect it is because the RRef arg is passed across the
        # pybind boundary and is not garbage collected by Python when calling
        # shutdown().
        import torch.distributed.rpc.api as api

        api._ignore_rref_leak = True

        local_ret = torch.ones(self.rank) + torch.ones(self.rank)

        n = self.rank + 1
        dst_rank = n % self.world_size
        remote_ref = rpc.remote(
            worker_name(dst_rank), construct_my_script_module, args=(self.rank,)
        )

        # Pass the RRef arg to its owner.
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            run_ref_script_module,
            args=(remote_ref, torch.ones(self.rank)),
        )
        self.assertEqual(ret, local_ret)

        # Pass the RRef arg to self/user.
        with self.assertRaisesRegex(
            RuntimeError,
            "is an RRef to a ScriptModule. It can't be sent through RPC from owner,",
        ):
            ret = rpc.rpc_sync(
                worker_name(self.rank),
                run_ref_script_module,
                args=(remote_ref, torch.ones(self.rank)),
            )

    @dist_init
    def test_create_script_module_on_remote(self):
        dst_name = worker_name((self.rank + 1) % self.world_size)
        # Construct on the remote end with rpc_sync.
        created_script_module = rpc.rpc_sync(
            dst_name, MyScriptModule, args=(self.rank,)
        )
        # Forward should output a ones tensor of size self.rank.
        self.assertTrue(isinstance(created_script_module, torch.jit.ScriptModule))
        rank_ones_tensor = created_script_module()
        self.assertEqual(torch.ones(self.rank), rank_ones_tensor)

        # Construct a ScriptModule with rpc.remote.
        remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,))
        # Verify it is an instance of ScriptModule on the remote end.
        remote_end_is_script = rpc.rpc_sync(
            remote_script_module.owner(),
            rref_isinstance,
            args=(remote_script_module, torch.jit.ScriptModule),
        )
        self.assertTrue(remote_end_is_script)
        # Run a forward pass remotely.
        remote_forward_output = remote_script_module.rpc_sync().forward()
        self.assertEqual(remote_forward_output, torch.ones(self.rank))
        # Run a function defined on the ScriptModule remotely.
        remote_func_output = remote_script_module.rpc_sync().custom_func()
        self.assertEqual(remote_func_output, torch.ones(self.rank))
        # Ensure we can transfer the ScriptModule RRef to this rank and run a
        # forward pass.
        local_script_module = remote_script_module.to_here()
        self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule))
        rank_ones_tensor = local_script_module()
        self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
        local_script_func_output = local_script_module.custom_func()
        self.assertEqual(local_script_func_output, torch.ones(self.rank))

    @dist_init
    def test_load_script_module_with_pickled_rref(self):
        dst_name = worker_name((self.rank + 1) % self.world_size)
        m1 = MyScriptModuleWithRRefs(dst_name)
        m2 = MyScriptModuleWithRRefs(dst_name)

        f = io.BytesIO()

        rpc._enable_jit_rref_pickle()
        torch.jit.save(m1, f)
        rpc._disable_jit_rref_pickle()

        out1 = rpc.rpc_sync(
            dst_name,
            load_script_module_with_pickled_rref,
            args=(f.getvalue(),),
        )
        out2 = m2()
        self.assertEqual(out1, out2)

    @dist_init
    def test_rref_jit_pickle_not_supported(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_var = rpc_return_rref(worker_name(dst_rank))
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(
                RuntimeError, "RRef jit pickling is only allowed inside RPC calls"
            ):
                save_rref(rref_var, fname)

    @dist_init
    def test_remote_script_throw(self):
        rref = rpc.remote(
            worker_name((self.rank + 1) % self.world_size),
            script_raise_func,
            args=(torch.ones(2),),
        )
        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
            rref.to_here()

    @dist_init
    def test_remote_script_udf(self):
        rref = rpc.remote(
            worker_name((self.rank + 1) % self.world_size),
            script_fork_wait_udf,
            args=(torch.ones(2),),
        )
        self.assertEqual(rref.to_here(), torch.ones(2) * 2)

    @dist_init
    def test_async_script_udf(self):
        future = rpc.rpc_async(
            worker_name((self.rank + 1) % self.world_size),
            script_fork_wait_udf,
            args=(torch.ones(2),),
        )
        self.assertEqual(future.wait(), torch.ones(2) * 2)

    @dist_init
    def test_callback_simple(self):
        def callback(fut):
            return fut.wait() + 1

        future = rpc.rpc_async(
            worker_name((self.rank + 1) % self.world_size),
            script_fork_wait_udf,
            args=(torch.ones(2),),
        ).then(callback)
        self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)

    @dist_init
    def test_callback_chain(self):
        n = self.rank + 1
        dst = worker_name(n % self.world_size)

        def callback(fut):
            return fut.wait() + 1

        fut = rpc.rpc_async(
            worker_name(n % self.world_size), one_arg, args=(torch.ones(n, n),)
        )

        num_cbs = 20
        for _ in range(num_cbs):
            fut = fut.then(callback)

        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)

    @dist_init
    def test_add_done_callback(self):
        callback_called = None

        def callback(fut):
            nonlocal callback_called
            callback_called = fut.wait() * 2

        future = rpc.rpc_async(
            worker_name((self.rank + 1) % self.world_size),
            script_fork_wait_udf,
            args=(torch.ones(2),),
        )

        future.add_done_callback(callback)
        future_then = future.then(lambda _: True)

        self.assertEqual(future.wait(), torch.ones(2) * 2)

        # We have no guarantee that the add_done_callback fn will execute
        # before the test finishes, so add a 'then' callback that runs
        # afterwards and wait on it to ensure the first callback has run.
        future_then.wait()
        self.assertEqual(callback_called, torch.ones(2) * 4)

    @dist_init
    def test_async_script_throw(self):
        future = rpc.rpc_async(
            worker_name((self.rank + 1) % self.world_size),
            script_fork_wait_throw,
            args=(torch.ones(2),),
        )
        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
            future.wait()

    @dist_init
    def test_callback_with_exception(self):
        def callback(fut):
            with self.assertRaisesRegex(Exception, ".*Expected error.*"):
                fut.wait()
            raise RuntimeError("Another expected error")

        future = rpc.rpc_async(
            worker_name((self.rank + 1) % self.world_size),
            script_fork_wait_throw,
            args=(torch.ones(2),),
        ).then(callback)

        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
            future.wait()

    @dist_init
    def test_call_rpc_with_profiling(self):
        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut
        # on a jit future from within a script function that calls rpc_async.
        if self.rank == 0:
            with _profile() as prof:
                prof_key = _build_rpc_profiling_key(
                    RPCExecMode.ASYNC,
                    torch._jit_internal._qualified_name(one_arg),
                    "worker0",
                    "worker1",
                )
                with torch.autograd.profiler.record_function(prof_key) as rf:
                    ret = call_rpc_with_profiling(rf.record, "worker1")
            # TODO: We can't get a reliable time for this profiling event,
            # since it's hard to estimate the execution time on the remote end
            # for non-UDFs. This can be resolved by
            # https://github.com/pytorch/pytorch/issues/36272. After that, this
            # test should be modified to validate the function time.
            events = prof.function_events
            function_event = get_function_event(events, prof_key)
            self.assertTrue(torch._jit_internal._qualified_name(one_arg) in function_event.name)

    @dist_init
    def test_rpc_async_jit_profiled(self):
        # Tests that rpc_async calls made from within a TorchScript function
        # are profiled.
        if self.rank == 0:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker_name = worker_name(dst_rank)
            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
            kwargs = {}
            with _profile() as prof:
                script_rpc_async_call(dst_worker_name, args, kwargs)

            # Ensure the rpc_async call is profiled.
            function_events = prof.function_events
            qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs)
            rpc_async_jit_event = [
                event
                for event in function_events
                if qual_name in event.name and event.node_id == self.rank
            ]
            self.assertEqual(len(rpc_async_jit_event), 1)
            rpc_async_jit_event = rpc_async_jit_event[0]
            profiled_name = _build_rpc_profiling_key(
                RPCExecMode.ASYNC_JIT,
                qual_name,
                worker_name(self.rank),
                dst_worker_name,
            )
            self.assertEqual(profiled_name, rpc_async_jit_event.name)
            remote_events = [event for event in function_events if event.is_remote]
            # All remote events should have taken place on dst_rank.
            remote_event_node_ids = {
                remote_event.node_id for remote_event in remote_events
            }
            self.assertEqual(remote_event_node_ids, {dst_rank})
            # script_rpc_async_call invokes the add operator, so we should see
            # it as a remote event.
            remote_add = next(
                remote_event
                for remote_event in remote_events
                if "aten::add" in remote_event.name
            )
            remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
            self.assertEqual(remote_add.name, remote_add_profiled_name)

    @dist_init
    def test_record_function_on_caller_rpc_async(self):
        if self.rank == 0:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker_name = worker_name(dst_rank)
            block_scope = "foo"
            with _profile() as prof:
                # Runs 2 rpc_async calls within JIT under record_function.
                record_function_on_caller_rpc_async(dst_worker_name, block_scope)

            # Ensure the record_function event is profiled.
            function_events = prof.function_events
            record_function_scope_event = [
                event for event in function_events if event.name == block_scope
            ]
            self.assertEqual(1, len(record_function_scope_event))
            record_function_scope_event = record_function_scope_event[0]
            # Ensure the RPC future is profiled.
            expected_key = _build_rpc_profiling_key(
                RPCExecMode.ASYNC_JIT,
                torch._jit_internal._qualified_name(script_add_ones),
                worker_name(self.rank),
                dst_worker_name,
            )
            jit_rpc_events = [
                event for event in function_events if event.name == expected_key
            ]
            self.assertEqual(2, len(jit_rpc_events))
            # Validate that the record_function scope time is greater than
            # each of the individual RPC async call times. It is not
            # necessarily greater than their sum, because the two calls can
            # execute in parallel.
            for jit_rpc_event in jit_rpc_events:
                self.assertTrue(
                    record_function_scope_event.cpu_time_total
                    > jit_rpc_event.cpu_time_total
                )

    @dist_init
    def test_rpc_torchscript_record_function(self):
        # Tests that TorchScript functions can be profiled using
        # `with record_function(...)` over RPC.
        REMOTE_OP_STR = "#remote_op: "
        if self.rank == 0:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker_name = worker_name(dst_rank)
            block_scope = "foo"
            with _profile() as prof:
                call_rpc_torchscript_with_record_function(dst_worker_name, block_scope)

            # key_averages() must be called below to populate CPU children.
            prof.key_averages()
            function_events = prof.function_events
            expected_key = (
                _build_rpc_profiling_key(
                    RPCExecMode.ASYNC_JIT,
                    torch._jit_internal._qualified_name(
                        script_add_ones_with_record_function
                    ),
                    worker_name(self.rank),
                    dst_worker_name,
                )
                + REMOTE_OP_STR
                + block_scope
            )
            remote_record_function_event = next(
                evt for evt in function_events if evt.name == expected_key
            )
            self.assertTrue(block_scope in remote_record_function_event.name)
            remote_children = remote_record_function_event.cpu_children
            self.assertTrue(any("aten::add" in child.name for child in remote_children))

    def test_record_function_jit_end_callbacks_with_fork(self):
        # Ensures that we can call rf._call_end_callbacks_on_future on a jit
        # future in Python eager mode with torch.jit.fork.
        sleep_interval = 1
        with _profile() as prof:
            with torch.autograd.profiler.record_function("foo") as rf:
                fut = torch.jit._fork(sleep, sleep_interval)
                rf._call_end_callbacks_on_future(fut)
            fut.wait()

        function_events = prof.function_events
        sleep_event = get_function_event(function_events, "foo")
        self.assertEqual(sleep_event.name, "foo")
        # Validate that the callbacks were fired at the right time by checking
        # the profiling event's CPU time.
        self.assertGreaterAlmostEqual(sleep_event.cpu_time * 1e-6, sleep_interval)

    def test_call_fork_in_jit_with_profiling(self):
        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut
        # on a jit future from within a script function with torch.jit.fork.
        with _profile() as prof:
            with torch.autograd.profiler.record_function("foo") as rf:
                ret = call_fork_with_profiling(rf.record)

        events = prof.function_events
        function_event = get_function_event(events, "foo")
        self.assertEqual(function_event.name, "foo")

    @dist_init
    def test_async_function_simple(self):
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)

        ret = rpc.rpc_sync(
            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
        )
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    @dist_init
    def test_async_function_wrong_return_type(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Async functions must return an IValue of Future type, but got Tensor",
        ):
            rpc.rpc_sync(
                worker_name((self.rank + 1) % self.world_size), async_wrong_type
            )

    @dist_init
    def test_async_function_wrong_decorator_order(self):
        # @torch.jit.script complains about the undefined value `rpc`; the
        # error is shown below. We don't check the error string, to avoid
        # making JIT error handling depend on RPC tests, since there are no
        # restrictions on the error message here.
        #
        # RuntimeError:
        # undefined value rpc:
        # def async_wrong_decorator_order(to, x, y):
        #     # type: (str, Tensor, Tensor) -> Future[Tensor]
        #     return rpc.rpc_async(to, script_add, (x, y))
        #            ~~~ <--- HERE
        with self.assertRaises(RuntimeError):

            @torch.jit.script
            @rpc.functions.async_execution
            def async_wrong_decorator_order(
                to: str, x: Tensor, y: Tensor
            ) -> Future[Tensor]:
                return rpc.rpc_async(to, script_add, (x, y))

    @dist_init
    def test_async_function_remote(self):
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)

        rref = rpc.remote(
            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
        )
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)

    @dist_init
    def test_async_function_remote_multi(self):
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)

        num = 20
        rrefs = []
        for i in range(num):
            rrefs.append(
                rpc.remote(
                    dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
                )
            )

        for i in range(num):
            self.assertEqual(rrefs[i].to_here(), torch.ones(2, 2) + i)

    @dist_init
    def test_async_function_wrong_return_type_remote(self):
        rref = rpc.remote(
            worker_name((self.rank + 1) % self.world_size), async_wrong_type
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Async functions must return an IValue of Future type, but got Tensor",
        ):
            rref.to_here()