xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/rpc/jit/rpc_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import time
4import io
5from typing import Dict, List, Tuple, Any
6
7import torch
8import torch.distributed as dist
9import torch.distributed.rpc as rpc
10from torch import Tensor
11from torch.autograd.profiler import record_function
12from torch.distributed.rpc import RRef
13from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key
14from torch.futures import Future
15from torch.testing._internal.common_utils import TemporaryFileName
16from torch.testing._internal.dist_utils import (
17    dist_init,
18    get_function_event,
19    initialize_pg,
20    worker_name,
21)
22from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
23    RpcAgentTestFixture,
24)
25
26from torch.autograd.profiler_legacy import profile as _profile
27
def rref_isinstance(rref, cls_to_check):
    """Return whether ``rref.local_value()`` is an instance of ``cls_to_check``."""
    held_value = rref.local_value()
    return isinstance(held_value, cls_to_check)
30
def sleep(t):
    """Block the calling thread for ``t`` seconds via ``time.sleep``."""
    time.sleep(t)
33
34
def rpc_return_rref(dst):
    """Ask worker ``dst`` to compute ``torch.ones(2, 2) + 1`` and return an RRef to it."""
    add_args = (torch.ones(2, 2), 1)
    return rpc.remote(dst, torch.add, args=add_args)
37
38
@torch.jit.script
def rref_local_value(rref: RRef[Tensor]) -> Tensor:
    """Return the tensor held by ``rref``; only valid when called on its owner."""
    owned_value = rref.local_value()
    return owned_value
42
43
@torch.jit.script
def list_create() -> List[int]:
    """Return a new ``[1, 2, 3]`` list on every call."""
    return [1, 2, 3]
48
49
@torch.jit.script
def rref_list_mutate(rref: RRef[List[int]]) -> None:
    """Append 4, 5 and 6 to the list held by ``rref``, one per access path.

    Uses ``local_value()``, ``to_here()`` and ``to_here(5.0)`` (explicit timeout
    argument), so it must run on the RRef's owner.
    """
    via_local_value = rref.local_value()
    via_local_value.append(4)
    via_to_here = rref.to_here()
    via_to_here.append(5)
    via_to_here_with_timeout = rref.to_here(5.0)
    via_to_here_with_timeout.append(6)
55
56
def return_value(value: int) -> int:
    """Identity function: hand back ``value`` unchanged."""
    result = value
    return result
59
60
class RRefAPITest:
    """Tests for RRef methods called from TorchScript functions.

    Mixin: relies on ``self.rank`` / ``self.world_size`` and the assertion
    helpers supplied by the concrete test class it is combined with.
    """

    @dist_init
    def test_rref_is_owner(self):
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref_var = rpc_return_rref(dst_worker_name)

        @torch.jit.script
        def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool:
            return rref_var.is_owner()

        # The RRef was created by rpc.remote on the destination worker, so this
        # worker is only a user.
        res = rref_tensor_is_owner(rref_var)
        self.assertEqual(res, False)

    @dist_init
    def test_rref_local_value(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref = rpc_return_rref(dst_worker_name)

        # local_value() is rejected on a non-owner ...
        with self.assertRaisesRegex(
            RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef"
        ):
            rref_local_value(rref)

        # ... but works when executed on the owner.
        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
        self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))

    @dist_init
    def test_local_rref_local_value(self):
        if self.rank != 0:
            return

        # The RRef is created on this very worker, so it is the owner.
        dst_worker_name = worker_name(self.rank)
        rref = rpc.remote(dst_worker_name, return_value, (5,), {})

        ret = rref_local_value(rref)
        self.assertEqual(ret, 5)

    def _create_rref(self):
        """Create an RRef owned by the worker two ranks ahead (wrapping around)."""
        owner_rank = (self.rank + 2) % self.world_size
        return rpc.remote(
            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
        )

    @dist_init
    def test_user_rrefs_confirmed(self):
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret = rpc.rpc_sync(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret, True)

    @dist_init
    def test_user_rrefs_confirmed_remote(self):
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret_rref = rpc.remote(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret_rref.to_here(), True)

    @dist_init
    def test_rref_list_mutate(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        list_rref = rpc.remote(dst, list_create)

        # The three appends done on the owner are visible in the fetched value.
        rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,))
        self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6])
132
133
@torch.jit.script
def no_arg():
    """Take no arguments and return the integer 0."""
    return 0
137
138
@torch.jit.script
def one_arg(value):
    """Return ``value`` (a Tensor, since it is unannotated) plus one."""
    incremented = value + 1
    return incremented
142
@torch.jit.script
def script_add_ones(x):
    """Return ``x + torch.ones(1)`` (broadcast over ``x``)."""
    return x + torch.ones(1)
146
@torch.jit.script
def script_add_ones_with_record_function(x, block: str):
    """Add ``torch.ones(1)`` to ``x`` inside a ``record_function`` range named ``block``."""
    with record_function(block):
        result = torch.add(x, torch.ones(1))
    return result
151
152
@torch.jit.script
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
    """Issue two ``rpc_async`` calls of ``script_add_ones`` inside one
    ``record_function`` range named ``block``, and return their summed results.

    Args:
        dst_worker_name: worker that runs ``script_add_ones``.
        block: name of the caller-side ``record_function`` range.

    Returns:
        ``script_add_ones(ones(1)) * 2`` (plus a zero tensor), i.e. ``tensor([4.])``.
    """
    t: Tensor = torch.ones(1)
    # The range handle itself is never used, so it is not bound to a name.
    with record_function(block):
        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
        # Extra operator call to avoid de-duplication of the next async call
        # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
        zero = torch.zeros_like(t)
        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
        res = fut1.wait() + fut2.wait() + zero
    return res
164
165
166
@torch.jit.script
def script_fork_wait_udf(tensor):
    """Run ``script_add_ones(tensor)`` in a forked task and return its result."""
    return torch.jit._wait(torch.jit._fork(script_add_ones, tensor))
172
173
@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    """Fetch and return the tensor referenced by ``rref_var``."""
    fetched = rref_var.to_here()
    return fetched
177
178
@torch.jit.script
def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    """Identity over an RRef: hand ``rref_var`` straight back to the caller."""
    echoed = rref_var
    return echoed
182
183
@torch.jit.script
def script_raise_func(value):
    """Return ``value + 1``, or raise ``ValueError`` when ``value`` has exactly 2 elements."""
    if value.numel() != 2:
        return value + 1
    raise ValueError("Expected error")
189
190
@torch.jit.script
def script_fork_wait_throw(invalue):
    """Fork ``script_raise_func(invalue)`` and wait on it, so any error it
    raises surfaces from the wait."""
    forked = torch.jit._fork(script_raise_func, invalue)
    return torch.jit._wait(forked)
196
197
@torch.jit.script
def call_rpc_with_profiling(record: torch.classes.profiler._RecordFunction, dst_worker_name: str) -> Tensor:
    """Call ``one_arg`` through ``rpc_async`` from TorchScript and attach the
    end callbacks of an already-started profiling range to the returned future.

    Args:
        record: the profiler's ``_RecordFunction`` TorchScript class instance
            for the range that should end when the RPC completes.
        dst_worker_name: worker that runs ``one_arg``.

    Returns:
        ``one_arg(torch.tensor(1))`` computed remotely, i.e. ``tensor(2)``.
    """
    fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
    # Ends `record` when `fut` completes, rather than when this function returns.
    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
    ret = fut.wait()
    return ret
207
@torch.jit.script
def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str) -> Tensor:
    """Run ``script_add_ones_with_record_function(tensor(1), block)`` on
    ``dst_worker_name`` and wait for its result."""
    remote_args = (torch.tensor(1), block)
    fut = rpc.rpc_async(dst_worker_name, script_add_ones_with_record_function, remote_args)
    return fut.wait()
212
213
@torch.jit.script
def call_fork_with_profiling(record: torch.classes.profiler._RecordFunction) -> Tensor:
    """Fork ``one_arg`` from TorchScript and attach the end callbacks of an
    already-started profiling range to the resulting future.

    Args:
        record: the profiler's ``_RecordFunction`` TorchScript class instance
            for the range that should end when the forked task completes.

    Returns:
        ``one_arg(torch.tensor(1))``, i.e. ``tensor(2)``.
    """
    fut = torch.jit._fork(one_arg, torch.tensor(1))
    # Ends `record` when `fut` completes, rather than when this function returns.
    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
    ret = fut.wait()
    return ret
223
224
class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
    """ScriptModule holding four RRefs, each created by ``rpc_return_rref(dst_worker)``.

    ``forward`` fetches every referenced tensor and accumulates it onto
    ``torch.ones(2, 2)``.
    """

    def __init__(self, dst_worker):
        super().__init__()
        self.rrefs = [rpc_return_rref(dst_worker) for _ in range(4)]

    @torch.jit.script_method
    def forward(self) -> Tensor:
        total = torch.ones(2, 2)
        for remote_value in self.rrefs:
            total += remote_value.to_here()

        return total
239
240
@torch.jit.ignore
def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    """Python-only identity over an RRef, callable from TorchScript via ``jit.ignore``."""
    passthrough = rref_var
    return passthrough
244
245
@torch.jit.script
def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor:
    """Route ``rref_var`` through the ignored Python identity, then fetch its value."""
    same_rref = rref_python_annotation(rref_var)
    return same_rref.to_here()
249
250
class RRefTypingTest:
    """Checks that RRefs flow through TorchScript functions as arguments and
    return values, and can be held as ScriptModule attributes.

    Mixin: expects ``self.rank`` / ``self.world_size`` and assertion helpers
    from the concrete test class.
    """

    @dist_init
    def test_rref_as_arg_and_return(self):
        peer = worker_name((self.rank + 1) % self.world_size)
        expected = one_arg(torch.ones(2, 2))

        # The RRef is owned by this rank.
        owned_rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),))

        # rpc_sync: the peer dereferences the RRef.
        self.assertEqual(rpc.rpc_sync(peer, rref_to_here, args=(owned_rref,)), expected)

        # rpc_sync: the peer hands the RRef back.
        echoed = rpc.rpc_sync(peer, return_rref, args=(owned_rref,))
        self.assertEqual(echoed.to_here(), expected)

        # remote: the peer dereferences the RRef.
        deref_on_peer = rpc.remote(peer, rref_to_here, args=(owned_rref,))
        self.assertEqual(deref_on_peer.to_here(), expected)

        # remote: the peer returns the RRef, yielding an RRef of an RRef.
        nested = rpc.remote(peer, return_rref, args=(owned_rref,))
        self.assertEqual(nested.to_here().to_here(), expected)

    @dist_init
    def test_my_script_module_with_rrefs(self):
        peer = worker_name((self.rank + 1) % self.world_size)
        module = MyScriptModuleWithRRefs(peer)
        # ones(2, 2) plus four remote copies of ones(2, 2) + 1 gives 9 everywhere.
        self.assertEqual(module(), torch.ones(2, 2) * 9)

    @dist_init
    def test_rref_python_annotation(self):
        rref_var = rpc_return_rref(worker_name((self.rank + 1) % self.world_size))
        self.assertEqual(rref_script_annotation(rref_var), torch.ones(2, 2) + 1)
294
295
class FutureTypingTest:
    """Tests that ``Future[Tensor]`` values cross the Python/TorchScript boundary.

    Mixin: expects ``self.rank`` / ``self.world_size`` and assertion helpers
    from the concrete test class.
    """

    @dist_init
    def test_future_passed_between_python_and_jit(self):
        """A Python-created Future can be waited on in TorchScript, and a
        TorchScript-created Future can be returned to Python and waited on."""
        dst_rank = (self.rank + 1) % self.world_size
        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs)
        # 1 + 2 plus the defaults 3 and 4 of two_args_two_kwargs.
        expected_res = torch.tensor([10, 10])

        # Python-created Future passed into a scripted function.
        @torch.jit.script
        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
            return fut.wait()

        self.assertEqual(future_wait_in_script(ret_fut), expected_res)

        # Scripted function creates the Future and returns it to Python.
        # NOTE(review): hard-codes the "worker{rank}" name format instead of
        # calling the Python helper worker_name(); presumably the two must
        # produce the same string — confirm they stay in sync.
        @torch.jit.script
        def future_return_to_python(
            dst_rank: int, inputs: Tuple[Tensor, Tensor]
        ) -> Future[Tensor]:
            return rpc.rpc_async(
                f"worker{dst_rank}", two_args_two_kwargs, inputs
            )

        fut_res = future_return_to_python(dst_rank, inputs)
        self.assertEqual(fut_res.wait(), expected_res)

    @dist_init
    def test_future_python_annotation(self):
        """A ``@torch.jit.ignore`` Python function annotated ``-> Future[Tensor]``
        can be called from TorchScript, and its Future waited on there."""
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        input_0 = torch.ones(2, 2)
        input_1 = 1
        expected_res = torch.add(input_0, input_1)

        # Ignored by the compiler, so it runs as regular Python and reads the
        # enclosing locals dst_worker_name / input_0 / input_1.
        @torch.jit.ignore
        def python_return_future() -> Future[Tensor]:
            fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})
            return fut

        @torch.jit.script
        def script_use_future() -> Tensor:
            fut = python_return_future()
            return fut.wait()

        res = script_use_future()
        self.assertEqual(res, expected_res)
343
344
# TorchScript class wrapping a single integer, used as an RRef payload.
@torch.jit.script
class MyScriptClass:
    def __init__(self, a: int):
        self.a = a

    def get_value(self) -> int:
        """Return the integer this instance was constructed with."""
        stored = self.a
        return stored
352
353
@torch.jit.interface
class MyModuleInterface(torch.nn.Module):
    # TorchScript module interface: a no-argument ``forward`` returning a Tensor.
    # Used as the RRef type hint for RRefs holding MyScriptModule instances
    # (an RRef to a ScriptModule must carry such a ModuleInterface hint).
    def forward(self) -> Tensor:
        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
        pass
359
360
class MyScriptModule(torch.jit.ScriptModule):
    """Legacy ScriptModule exposing a stored ``torch.ones(rank)`` tensor."""

    def __init__(self, rank):
        super().__init__()
        self.a = torch.ones(rank)

    @torch.jit.script_method
    def forward(self) -> Tensor:
        """Return the stored tensor."""
        stored = self.a
        return stored

    @torch.jit.script_method
    def custom_func(self) -> Tensor:
        """Return the stored tensor from a method other than ``forward``."""
        stored = self.a
        return stored
373
374
def owner_create_rref_my_script_class(a):
    """Create, on the calling worker, an owned RRef wrapping ``MyScriptClass(a)``."""
    script_obj = MyScriptClass(a)
    return rpc.RRef(script_obj)
377
378
def owner_create_rref_my_script_module(a):
    """Create an owned RRef to ``MyScriptModule(a)``, typed as ``MyModuleInterface``."""
    module = MyScriptModule(a)
    return rpc.RRef(module, type_hint=MyModuleInterface)
381
382
@torch.jit.script
def script_rref_get_value_my_script_class(rref: RRef[MyScriptClass]) -> int:
    """Fetch the ``MyScriptClass`` behind ``rref`` and return its ``get_value()``."""
    obj = rref.to_here()
    return obj.get_value()
386
387
@torch.jit.script
def script_rref_run_forward_my_script_module(rref: RRef[MyModuleInterface]) -> Tensor:
    """Fetch the module behind ``rref`` and return the result of its ``forward()``."""
    module = rref.to_here()
    return module.forward()
391
392
class LocalRRefTest:
    """Tests for RRefs that hold TorchScript classes and modules.

    Mixin: expects ``self.rank`` / ``self.world_size`` and assertion helpers
    from the concrete test class. Every test only runs its body on rank 0.
    """

    @dist_init
    def test_create_local_script_class_rref_in_py(self):
        if self.rank != 0:
            return

        # Create a local RRef<MyScriptClass>.
        rref_script_class = rpc.RRef(MyScriptClass(self.rank))
        ret = rref_script_class.to_here().get_value()
        self.assertEqual(ret, self.rank)

    @dist_init
    def test_create_local_script_module_rref_in_py(self):
        if self.rank != 0:
            return

        # Create a local RRef<MyModuleInterface>.
        rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
        ret = rref_script_module.to_here().forward()
        self.assertEqual(ret, torch.ones(self.rank))

        # Creating an RRef to a ScriptModule without its ModuleInterface type
        # hint must fail.
        with self.assertRaisesRegex(
            RuntimeError,
            (
                "The RRef being created contains a ScriptModule, "
                "must provide its ModuleInterface type hint."
            ),
        ):
            rpc.RRef(MyScriptModule(self.rank))

    @dist_init
    def test_return_local_script_class_rref_in_py_and_use_in_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Have the destination worker create (and own) an RRef<MyScriptClass>,
        # returning it to this worker through a Python RPC.
        rref = rpc.rpc_sync(
            dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,)
        )

        # Sends the RRef back to its owner, which reads the wrapped value.
        # Called both as plain Python and after torch.jit.script below.
        def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int:
            args = (rref,)
            kwargs: Dict[str, Any] = {}
            fut = rpc.rpc_async(
                rref.owner(), script_rref_get_value_my_script_class, args, kwargs
            )
            ret = fut.wait()
            return ret

        # Use RRef<MyScriptClass> in local Python RPC and remote Script run.
        ret = use_rref_on_owner(rref)
        self.assertEqual(ret, self.rank)

        # Use RRef<MyScriptClass> in local Script RPC and remote Script run.
        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
        ret = use_rref_on_owner_script(rref)
        self.assertEqual(ret, self.rank)

    @dist_init
    def test_return_local_script_module_rref_in_py_and_use_in_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Have the destination worker create (and own) an RRef<MyModuleInterface>,
        # returning it to this worker through a Python RPC.
        rref = rpc.rpc_sync(
            dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,)
        )

        # Sends the RRef back to its owner, which runs the module's forward().
        # Called both as plain Python and after torch.jit.script below.
        def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor:
            args = (rref,)
            kwargs: Dict[str, Any] = {}
            fut = rpc.rpc_async(
                rref.owner_name(),
                script_rref_run_forward_my_script_module,
                args,
                kwargs,
            )
            ret = fut.wait()
            return ret

        # Use RRef<MyModuleInterface> in local Python RPC and remote Script run.
        ret = use_rref_on_owner(rref)
        self.assertEqual(ret, torch.ones(self.rank))

        # Use RRef<MyModuleInterface> in local Script RPC and remote Script run.
        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
        ret = use_rref_on_owner_script(rref)
        self.assertEqual(ret, torch.ones(self.rank))
486
487
def python_function():
    """Plain (non-scripted) Python function that returns 0."""
    result = 0
    return result
490
491
@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    """Return the element-wise sum of all four tensors.

    When only the positional arguments are given, the defaults add 3 and 4.
    """
    positional_sum = first_arg + second_arg
    return positional_sum + first_kwarg + second_kwarg
500
501
@torch.jit.script
def assorted_types_args_kwargs(
    tensor_arg: Tensor,  # noqa: E999
    str_arg: str,
    int_arg: int,
    tensor_kwarg: Tensor = torch.tensor([2, 2]),
    str_kwarg: str = "str_kwarg",
    int_kwarg: int = 2,
):
    """Combine each positional argument with its same-typed keyword partner.

    Returns a ``(Tensor, str, int)`` tuple: tensor sum, string concatenation,
    and integer sum.
    """
    summed_tensor = tensor_arg + tensor_kwarg
    joined_str = str_arg + str_kwarg
    summed_int = int_arg + int_kwarg
    return summed_tensor, joined_str, summed_int
512
513
@torch.jit.script
def raise_script():
    """Unconditionally raise ``RuntimeError("Expected error")``."""
    raise RuntimeError("Expected error")
517
518
@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    """Run ``two_args_two_kwargs(*args, **kwargs)`` via ``rpc_async`` and wait for it."""
    return rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs).wait()
526
@torch.jit.script
def script_rpc_sync_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    """Run ``two_args_two_kwargs(*args, **kwargs)`` via ``rpc_sync``."""
    return rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
533
@torch.jit.script
def script_rpc_remote_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    """Run ``two_args_two_kwargs(*args, **kwargs)`` via ``remote`` and fetch the result."""
    return rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs).to_here()
540
class JitRpcOpTest:
    """Tests for calling ``rpc_async`` / ``rpc_sync`` / ``remote`` from TorchScript.

    Mixin: relies on ``self.rank``, ``self.world_size`` and the assertion
    helpers of the concrete test class. Every test only runs its body on
    rank 0 and targets the next rank.
    """

    # Call functions remotely from Script.
    @dist_init
    def test_all_kwargs_are_populated_by_defaults(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {}

        # 1 + 2 + 3 (default first_kwarg) + 4 (default second_kwarg) == 10.
        for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
            ret = script_op(
                dst_worker_name, args, kwargs
            )
            self.assertEqual(ret, torch.tensor([10, 10]))

    @dist_init
    def test_some_kwargs_are_populated_by_defaults(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {"first_kwarg": torch.tensor([2, 2])}

        # 1 + 2 + 2 (given first_kwarg) + 4 (default second_kwarg) == 9.
        for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
            ret = script_op(
                dst_worker_name, args, kwargs
            )
            self.assertEqual(ret, torch.tensor([9, 9]))

    @dist_init
    def test_no_kwargs_are_populated_by_defaults(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        # 1 + 2 + 2 + 3 == 8: both keyword arguments override their defaults.
        for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
            ret = script_op(
                dst_worker_name, args, kwargs
            )
            self.assertEqual(ret, torch.tensor([8, 8]))

    @dist_init
    def test_args_and_kwargs_contain_different_types(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def script_rpc_async_call_with_assorted_types(
            dst_worker_name: str,
        ):
            args = (torch.tensor([1, 1]), "str_arg", 1)
            # Must annotate the value type as `Any`, because JIT type inference
            # does not support multiple types when defining a Dict.
            # The error JIT gives is,
            # "Dict values must contain only a single type, "
            # "expected: Tensor but found str instead."
            kwargs: Dict[str, Any] = {
                "tensor_kwarg": torch.tensor([3, 3]),
                "str_kwarg": "_str_kwarg",
                "int_kwarg": 3,
            }
            fut = rpc.rpc_async(
                dst_worker_name, assorted_types_args_kwargs, args, kwargs
            )
            ret = fut.wait()
            return ret

        ret = script_rpc_async_call_with_assorted_types(
            dst_worker_name
        )
        self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4))

    @dist_init
    def test_kwargs_not_passed(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def script_rpc_async_call_without_kwargs_passed(
            dst_worker_name: str,
        ):
            args = ()
            fut = rpc.rpc_async(dst_worker_name, no_arg, args)
            ret = fut.wait()
            return ret

        ret = script_rpc_async_call_without_kwargs_passed(
            dst_worker_name
        )
        self.assertEqual(ret, 0)

    @dist_init
    def test_args_kwargs_are_neither_passed(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def script_rpc_async_call_without_args_kwargs_passed(
            dst_worker_name: str,
        ):
            fut = rpc.rpc_async(dst_worker_name, no_arg)
            ret = fut.wait()
            return ret

        ret = script_rpc_async_call_without_args_kwargs_passed(
            dst_worker_name
        )
        self.assertEqual(ret, 0)

    @dist_init
    def test_less_than_needed_args_are_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Notice, args matching happens during scripting.
        with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"):

            @torch.jit.script
            def script_rpc_async_call_with_less_args(
                dst_worker_name: str,  # noqa: E999
            ):
                args = (torch.tensor([1, 1]),)
                kwargs = {}
                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
                ret = fut.wait()
                return ret

    @dist_init
    def test_more_than_needed_args_are_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Notice, args matching happens during scripting.
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected at most 4 arguments but found 5 positional arguments",
        ):

            @torch.jit.script
            def script_rpc_async_call_with_more_args(
                dst_worker_name: str,
            ):
                args = (
                    torch.tensor([1, 1]),
                    torch.tensor([2, 2]),
                    torch.tensor([3, 3]),
                    torch.tensor([4, 4]),
                    torch.tensor([5, 5]),
                )
                kwargs = {}
                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
                ret = fut.wait()
                return ret

    # NOTE: the misspelled test name is kept as-is so the test ID does not change.
    @dist_init
    def test_unexepected_kwarg_is_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Notice, kwargs matching happens during execution.
        @torch.jit.script
        def script_rpc_async_call_with_unexpected_kwarg(
            dst_worker_name: str,  # noqa: E999
        ):
            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
            kwargs = {"third_kwarg": torch.tensor([1, 1])}
            fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(
            RuntimeError, "Unknown keyword argument 'third_kwarg'"
        ):
            script_rpc_async_call_with_unexpected_kwarg(
                dst_worker_name
            )

    @dist_init
    def test_call_python_function_remotely_from_script_not_supported(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(
            RuntimeError, "attempted to get undefined function"
        ):
            rpc_async_call_remote_py_function_in_torchscript(dst_worker_name)

    @dist_init
    def test_call_script_function_that_raises_remotely_from_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # raise_script raises RuntimeError("Expected error") on the remote
        # worker; the error reaches this worker through the future, and the
        # original message text is expected to be part of what is raised here.
        @torch.jit.script
        def rpc_async_call_remote_raising_torchscript_in_torchscript(
            dst_worker_name: str,
        ):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            rpc_async_call_remote_raising_torchscript_in_torchscript(
                dst_worker_name
            )

    @dist_init
    def test_call_script_function_that_not_exists_remotely_from_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Defined only locally inside this test, so it is expected to be
        # unknown to the remote worker.
        @torch.jit.script
        def nonexisting_script():
            return 0

        @torch.jit.script
        def rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
            dst_worker_name: str,
        ):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, nonexisting_script, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(
            RuntimeError, "attempted to get undefined function nonexisting_script"
        ):
            rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
                dst_worker_name
            )
817
818
@torch.jit.ignore
def my_script_module_init(rank: int) -> MyModuleInterface:
    """Build ``MyScriptModule(rank)`` in Python; the TorchScript compiler skips this body."""
    module = MyScriptModule(rank)
    return module
822
823
@torch.jit.script
def construct_my_script_module(rank: int) -> MyModuleInterface:
    """TorchScript entry point that delegates to the ignored Python constructor."""
    module = my_script_module_init(rank)
    return module
827
828
@torch.jit.script
def run_ref_script_module(
    ref_script_module: RRef[MyModuleInterface], t: Tensor
) -> Tensor:
    """Fetch the module behind ``ref_script_module`` and return ``forward() + t``."""
    return ref_script_module.to_here().forward() + t
835
836
@torch.jit.script
def script_check_rref_confirmed(rref: RRef[Tensor]) -> bool:
    """Report whether ``rref`` has been confirmed by its owner."""
    confirmed = rref.confirmed_by_owner()
    return confirmed
840
841
@torch.jit.script
def save_rref(rref_var: RRef[Tensor], fname: str) -> None:
    """Call ``torch.save`` on ``rref_var`` with target path ``fname`` from TorchScript."""
    target_path = fname
    torch.save(rref_var, target_path)
845
846
@torch.jit.script
def script_add(x: Tensor, y: Tensor) -> Tensor:
    """Return the element-wise sum of ``x`` and ``y``."""
    return torch.add(x, y)
850
851
@rpc.functions.async_execution
@torch.jit.script
def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
    """Forward ``script_add(x, y)`` to worker ``to`` and return the pending Future.

    Marked ``async_execution``, so the returned Future (not a Tensor) is the
    result of the RPC that invoked this function.
    """
    pending = rpc.rpc_async(to, script_add, (x, y))
    return pending
856
857
@rpc.functions.async_execution
@torch.jit.script
def async_wrong_type() -> Tensor:
    """Marked ``async_execution`` but returns a Tensor instead of a Future.

    NOTE(review): presumably exists to exercise the error path for an
    async-execution function with the wrong return type — confirm in callers.
    """
    return torch.zeros(2)
862
863
def load_script_module_with_pickled_rref(pickled_script_module):
    """Load a ScriptModule from serialized ``bytes`` and return ``module()``."""
    with io.BytesIO(pickled_script_module) as buffer:
        module = torch.jit.load(buffer)
        return module()
868
869
870class JitRpcTest(
871    RRefAPITest,
872    RRefTypingTest,
873    LocalRRefTest,
874    JitRpcOpTest,
875    FutureTypingTest,
876    RpcAgentTestFixture,
877):
878    @dist_init
879    def test_torchscript_function(self):
880        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
881        local_ret = one_arg(torch.ones(2, 2))
882        ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
883        self.assertEqual(ret, local_ret)
884        rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
885        self.assertEqual(rref.to_here(), local_ret)
886        # create rref to itself
887        local_rref = rpc.remote(
888            worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)
889        )
890        self.assertEqual(local_rref.to_here(), local_ret)
891
892    @dist_init
893    def test_torchscript_function_exception(self):
894        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
895        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
896            ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20))
897
898        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
899            rref = rpc.remote(dst_worker_name, one_arg, args=(10, 20))
900
901    @dist_init
902    def test_torchscript_functions_not_supported(self):
903        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
904
905        my_local_script_module = MyScriptModule(self.rank)
906
907        # It is not thread safe to instantiate MyScriptModule in multiple threads,
908        # wait for local MyScriptModule instantiation to finish,
909        # otherwise it could instantiate MyScriptModule in parallel with
910        # server thread in the below
911        initialize_pg(self.file_init_method, self.rank, self.world_size)
912        dist.barrier()
913
914        # rpc_sync still accepts script class and run it in
915        # the same code path as python call.
916        ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,))
917
918        # rpc_sync does not accept script module method.
919        # Python 3.5 and Python 3.6 throw different error message, the only
920        # common word can be greped is "pickle".
921        with self.assertRaisesRegex(TypeError, "pickle"):
922            ret = rpc.rpc_async(
923                dst_worker_name, my_local_script_module.forward, args=()
924            )
925
926    @dist_init
927    def test_remote_script_module(self):
928        # TODO, need more investigation
929        # there is rref leak when shutting down, suspect it is because
930        # ref as arg is passed to pybind boundary, and the ref is not garbage
931        # collected by python when calling shutdown()
932        import torch.distributed.rpc.api as api
933
934        api._ignore_rref_leak = True
935
936        local_ret = torch.ones(self.rank) + torch.ones(self.rank)
937
938        n = self.rank + 1
939        dst_rank = n % self.world_size
940        remote_ref = rpc.remote(
941            worker_name(dst_rank), construct_my_script_module, args=(self.rank,)
942        )
943
944        # pass rref arg to owner
945        ret = rpc.rpc_sync(
946            worker_name(dst_rank),
947            run_ref_script_module,
948            args=(remote_ref, torch.ones(self.rank)),
949        )
950        self.assertEqual(ret, local_ret)
951
952        # pass rref arg to self/user
953        with self.assertRaisesRegex(
954            RuntimeError,
955            "is an RRef to a ScriptModule. It can't be sent through RPC from owner,",
956        ):
957            ret = rpc.rpc_sync(
958                worker_name(self.rank),
959                run_ref_script_module,
960                args=(remote_ref, torch.ones(self.rank)),
961            )
962
963    @dist_init
964    def test_create_script_module_on_remote(self):
965        dst_name = worker_name((self.rank + 1) % self.world_size)
966        # Construct on remote end with rpc_sync
967        created_script_module = rpc.rpc_sync(
968            dst_name, MyScriptModule, args=(self.rank,)
969        )
970        # Forward should output a ones tensor of self.rank.
971        self.assertTrue(isinstance(created_script_module, torch.jit.ScriptModule))
972        rank_ones_tensor = created_script_module()
973        self.assertEqual(torch.ones(self.rank), rank_ones_tensor)
974
975        # Construct ScriptModule with rpc.remote.
976        remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,))
977        # Verify it is an instance of ScriptModule on remote end.
978        remote_end_is_script = rpc.rpc_sync(
979            remote_script_module.owner(),
980            rref_isinstance,
981            args=(remote_script_module, torch.jit.ScriptModule),
982        )
983        self.assertTrue(remote_end_is_script)
984        # Run forward pass remotely.
985        remote_forward_output = remote_script_module.rpc_sync().forward()
986        self.assertEqual(remote_forward_output, torch.ones(self.rank))
987        # Run function defined on ScriptModule remotely.
988        remote_func_output = remote_script_module.rpc_sync().custom_func()
989        self.assertEqual(remote_func_output, torch.ones(self.rank))
990        # Ensure we can transfer ScriptModule RRef to this rank and run
991        # forward pass.
992        local_script_module = remote_script_module.to_here()
993        self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule))
994        rank_ones_tensor = local_script_module()
995        self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
996        local_script_func_output = local_script_module.custom_func()
997        self.assertEqual(local_script_func_output, torch.ones(self.rank))
998
999    @dist_init
1000    def test_load_script_module_with_pickled_rref(self):
1001        dst_name = worker_name((self.rank + 1) % self.world_size)
1002        m1 = MyScriptModuleWithRRefs(dst_name)
1003        m2 = MyScriptModuleWithRRefs(dst_name)
1004
1005        f = io.BytesIO()
1006
1007        rpc._enable_jit_rref_pickle()
1008        torch.jit.save(m1, f)
1009        rpc._disable_jit_rref_pickle()
1010
1011        out1 = rpc.rpc_sync(
1012            dst_name,
1013            load_script_module_with_pickled_rref,
1014            args=(f.getvalue(),)
1015        )
1016        out2 = m2()
1017        self.assertEqual(out1, out2)
1018
1019    @dist_init
1020    def test_rref_jit_pickle_not_supported(self):
1021        n = self.rank + 1
1022        dst_rank = n % self.world_size
1023        rref_var = rpc_return_rref(worker_name(dst_rank))
1024        with TemporaryFileName() as fname:
1025            with self.assertRaisesRegex(
1026                RuntimeError, "RRef jit pickling is only allowed inside RPC calls"
1027            ):
1028                save_rref(rref_var, fname)
1029
1030    @dist_init
1031    def test_remote_script_throw(self):
1032        rref = rpc.remote(
1033            worker_name((self.rank + 1) % self.world_size),
1034            script_raise_func,
1035            args=(torch.ones(2),),
1036        )
1037        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
1038            rref.to_here()
1039
1040    @dist_init
1041    def test_remote_script_udf(self):
1042        rref = rpc.remote(
1043            worker_name((self.rank + 1) % self.world_size),
1044            script_fork_wait_udf,
1045            args=(torch.ones(2),),
1046        )
1047        self.assertEqual(rref.to_here(), torch.ones(2) * 2)
1048
1049    @dist_init
1050    def test_async_script_udf(self):
1051        future = rpc.rpc_async(
1052            worker_name((self.rank + 1) % self.world_size),
1053            script_fork_wait_udf,
1054            args=(torch.ones(2),),
1055        )
1056        self.assertEqual(future.wait(), torch.ones(2) * 2)
1057
1058    @dist_init
1059    def test_callback_simple(self):
1060        def callback(fut):
1061            return fut.wait() + 1
1062
1063        future = rpc.rpc_async(
1064            worker_name((self.rank + 1) % self.world_size),
1065            script_fork_wait_udf,
1066            args=(torch.ones(2),),
1067        ).then(callback)
1068        self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)
1069
1070    @dist_init
1071    def test_callback_chain(self):
1072        n = self.rank + 1
1073        dst = worker_name(n % self.world_size)
1074
1075        def callback(fut):
1076            return fut.wait() + 1
1077
1078        fut = rpc.rpc_async(
1079            worker_name(n % self.world_size), one_arg, args=(torch.ones(n, n),)
1080        )
1081
1082        num_cbs = 20
1083        for _ in range(num_cbs):
1084            fut = fut.then(callback)
1085
1086        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
1087
1088    @dist_init
1089    def test_add_done_callback(self):
1090        callback_called = None
1091
1092        def callback(fut):
1093            nonlocal callback_called
1094            callback_called = fut.wait() * 2
1095
1096        future = rpc.rpc_async(
1097            worker_name((self.rank + 1) % self.world_size),
1098            script_fork_wait_udf,
1099            args=(torch.ones(2),),
1100        )
1101
1102        future.add_done_callback(callback)
1103        future_then = future.then(lambda _: True)
1104
1105        self.assertEqual(future.wait(), torch.ones(2) * 2)
1106
1107        # We have no guarantee that the add_done_callback fn will execute before the test finishes.
1108        # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback
1109        future_then.wait()
1110        self.assertEqual(callback_called, torch.ones(2) * 4)
1111
1112    @dist_init
1113    def test_async_script_throw(self):
1114        future = rpc.rpc_async(
1115            worker_name((self.rank + 1) % self.world_size),
1116            script_fork_wait_throw,
1117            args=(torch.ones(2),),
1118        )
1119        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
1120            future.wait()
1121
1122    @dist_init
1123    def test_callback_with_exception(self):
1124        def callback(fut):
1125            with self.assertRaisesRegex(Exception, ".*Expected error.*"):
1126                fut.wait()
1127            raise RuntimeError("Another expected error")
1128
1129        future = rpc.rpc_async(
1130            worker_name((self.rank + 1) % self.world_size),
1131            script_fork_wait_throw,
1132            args=(torch.ones(2),),
1133        ).then(callback)
1134
1135        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
1136            future.wait()
1137
1138    @dist_init
1139    def test_call_rpc_with_profiling(self):
1140        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
1141        # future from within a script function that calls rpc_async
1142        if self.rank == 0:
1143            with _profile() as prof:
1144                prof_key = _build_rpc_profiling_key(
1145                    RPCExecMode.ASYNC,
1146                    torch._jit_internal._qualified_name(one_arg),
1147                    "worker0",
1148                    "worker1",
1149                )
1150                with torch.autograd.profiler.record_function(prof_key) as rf:
1151                    ret = call_rpc_with_profiling(rf.record, "worker1")
1152            # TODO: Can't get a reliable time for this profiling event since
1153            # it's hard to estimate the execution time on the remote end for non-UDFs.
1154            # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
1155            # After that, this test should be modified to validate the function time.
1156            events = prof.function_events
1157            function_event = get_function_event(events, prof_key)
1158            self.assertTrue(torch._jit_internal._qualified_name(one_arg) in function_event.name)
1159
1160    @dist_init
1161    def test_rpc_async_jit_profiled(self):
1162        # Tests that rpc_async calls made from within a TorchScript function are
1163        # profiled.
1164        if self.rank == 0:
1165            dst_rank = (self.rank + 1) % self.world_size
1166            dst_worker_name = worker_name(dst_rank)
1167            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
1168            kwargs = {}
1169            with _profile() as prof:
1170                script_rpc_async_call(
1171                    dst_worker_name, args, kwargs
1172                )
1173
1174            # Ensure rpc_async call is profiled
1175            function_events = prof.function_events
1176            qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs)
1177            rpc_async_jit_event = [
1178                event
1179                for event in function_events
1180                if qual_name in event.name and event.node_id == self.rank
1181            ]
1182            self.assertEqual(len(rpc_async_jit_event), 1)
1183            rpc_async_jit_event = rpc_async_jit_event[0]
1184            profiled_name = _build_rpc_profiling_key(
1185                RPCExecMode.ASYNC_JIT,
1186                qual_name,
1187                worker_name(self.rank),
1188                dst_worker_name,
1189            )
1190            self.assertEqual(profiled_name, rpc_async_jit_event.name)
1191            remote_events = [event for event in function_events if event.is_remote]
1192            # All remote events should have taken place on dst_rank
1193            remote_event_node_ids = {
1194                remote_event.node_id for remote_event in remote_events
1195            }
1196            self.assertEqual(remote_event_node_ids, {dst_rank})
1197            # script_rpc_async_call invokes add operator
1198            # so we should see this as a remote event.
1199            remote_add = next(
1200                remote_event
1201                for remote_event in remote_events
1202                if "aten::add" in remote_event.name
1203            )
1204            remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
1205            self.assertEqual(remote_add.name, remote_add_profiled_name)
1206
1207    @dist_init
1208    def test_record_function_on_caller_rpc_async(self):
1209        if self.rank == 0:
1210            dst_rank = (self.rank + 1) % self.world_size
1211            dst_worker_name = worker_name(dst_rank)
1212            block_scope = "foo"
1213            with _profile() as prof:
1214                # Runs 2 rpc_async calls within JIT under record_function.
1215                record_function_on_caller_rpc_async(dst_worker_name, block_scope)
1216
1217            # Ensure record_function event is profiled.
1218            function_events = prof.function_events
1219            record_function_scope_event = [
1220                event for event in function_events if event.name == block_scope
1221            ]
1222            self.assertEqual(1, len(record_function_scope_event))
1223            record_function_scope_event = record_function_scope_event[0]
1224            # Ensure RPC future is profiled.
1225            expected_key = _build_rpc_profiling_key(
1226                RPCExecMode.ASYNC_JIT,
1227                torch._jit_internal._qualified_name(script_add_ones),
1228                worker_name(self.rank),
1229                dst_worker_name,
1230            )
1231            jit_rpc_events = [
1232                event for event in function_events if event.name == expected_key
1233            ]
1234            self.assertEqual(2, len(jit_rpc_events))
1235            # Validate that the record_function scope time is greater than both
1236            # of the individual RPC async call times. The reason it is not necessarily
1237            # greater than the sum is because the two can execute in parallel.
1238            for jit_rpc_event in jit_rpc_events:
1239                self.assertTrue(
1240                    record_function_scope_event.cpu_time_total
1241                    > jit_rpc_event.cpu_time_total
1242                )
1243
1244    @dist_init
1245    def test_rpc_torchscript_record_function(self):
1246        # tests that torchscript functions can be profiled using with
1247        # record_function(...) over RPC.
1248        REMOTE_OP_STR = "#remote_op: "
1249        if self.rank == 0:
1250            dst_rank = (self.rank + 1) % self.world_size
1251            dst_worker_name = worker_name(dst_rank)
1252            block_scope = "foo"
1253            with _profile() as prof:
1254                call_rpc_torchscript_with_record_function(dst_worker_name, block_scope)
1255
1256            # Need to call below to populate CPU children.
1257            prof.key_averages()
1258            function_events = prof.function_events
1259            expected_key = (
1260                _build_rpc_profiling_key(
1261                    RPCExecMode.ASYNC_JIT,
1262                    torch._jit_internal._qualified_name(
1263                        script_add_ones_with_record_function
1264                    ),
1265                    worker_name(self.rank),
1266                    dst_worker_name,
1267                )
1268                + REMOTE_OP_STR
1269                + block_scope
1270            )
1271            remote_record_function_event = next(
1272                evt for evt in function_events if evt.name == expected_key
1273            )
1274            self.assertTrue(block_scope in remote_record_function_event.name)
1275            remote_children = remote_record_function_event.cpu_children
1276            self.assertTrue("aten::add" in child.name for child in remote_children)
1277
1278    def test_record_function_jit_end_callbacks_with_fork(self):
1279        # Ensures that we can call rf._call_end_callbacks_on_future on a jit
1280        # future in python eager mode with torch.jit.fork
1281        sleep_interval = 1
1282        with _profile() as prof:
1283            with torch.autograd.profiler.record_function("foo") as rf:
1284                fut = torch.jit._fork(sleep, sleep_interval)
1285                rf._call_end_callbacks_on_future(fut)
1286            fut.wait()
1287
1288        function_events = prof.function_events
1289        sleep_event = get_function_event(function_events, "foo")
1290        self.assertEqual(sleep_event.name, "foo")
1291        # Validate that callbacks were fired at the right time by checking the
1292        # profiling event cpu time
1293        self.assertGreaterAlmostEqual(sleep_event.cpu_time * 1e-6, sleep_interval)
1294
1295    def test_call_fork_in_jit_with_profiling(self):
1296        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
1297        # future from within a script function with torch.jit.fork
1298        with _profile() as prof:
1299            with torch.autograd.profiler.record_function("foo") as rf:
1300                ret = call_fork_with_profiling(rf.record)
1301
1302        events = prof.function_events
1303        function_event = get_function_event(events, "foo")
1304        self.assertEqual(function_event.name, "foo")
1305
1306    @dist_init
1307    def test_async_function_simple(self):
1308        dst1 = worker_name((self.rank + 1) % self.world_size)
1309        dst2 = worker_name((self.rank + 2) % self.world_size)
1310
1311        ret = rpc.rpc_sync(
1312            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
1313        )
1314        self.assertEqual(ret, torch.ones(2, 2) + 1)
1315
1316    @dist_init
1317    def test_async_function_wrong_return_type(self):
1318        with self.assertRaisesRegex(
1319            RuntimeError,
1320            "Async functions must return an IValue of Future type, but got Tensor",
1321        ):
1322            rpc.rpc_sync(
1323                worker_name((self.rank + 1) % self.world_size), async_wrong_type
1324            )
1325
    @dist_init
    def test_async_function_wrong_decorator_order(self):
        """Scripting a function wrapped by ``rpc.functions.async_execution`` must fail.

        Applies ``@torch.jit.script`` outside ``@rpc.functions.async_execution``
        (the order this test treats as wrong). A ``RuntimeError`` is expected
        while the decorated function is being defined.
        """
        # @torch.jit.script complains about the undefined value `rpc`; the
        # error is shown below. The error string is deliberately not checked,
        # to avoid making JIT error-handling code depend on RPC tests, as we
        # don't have any restrictions on the error message here.
        #
        # RuntimeError:
        # undefined value rpc:
        # def async_wrong_decorator_order(to, x, y):
        #    # type: (str, Tensor, Tensor) -> Future[Tensor]
        #    return rpc.rpc_async(to, script_add, (x, y))
        #           ~~~ <--- HERE
        #
        # The decorated definition sits inside the `with` block because the
        # error is expected at definition time, when @torch.jit.script runs.
        with self.assertRaises(RuntimeError):

            @torch.jit.script
            @rpc.functions.async_execution
            def async_wrong_decorator_order(
                to: str, x: Tensor, y: Tensor
            ) -> Future[Tensor]:
                return rpc.rpc_async(to, script_add, (x, y))
1347
1348    @dist_init
1349    def test_async_function_remote(self):
1350        dst1 = worker_name((self.rank + 1) % self.world_size)
1351        dst2 = worker_name((self.rank + 2) % self.world_size)
1352
1353        rref = rpc.remote(
1354            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
1355        )
1356        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
1357
1358    @dist_init
1359    def test_async_function_remote_multi(self):
1360        dst1 = worker_name((self.rank + 1) % self.world_size)
1361        dst2 = worker_name((self.rank + 2) % self.world_size)
1362
1363        num = 20
1364        rrefs = []
1365        for i in range(num):
1366            rrefs.append(
1367                rpc.remote(
1368                    dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
1369                )
1370            )
1371
1372        for i in range(num):
1373            self.assertEqual(rrefs[i].to_here(), torch.ones(2, 2) + i)
1374
1375    @dist_init
1376    def test_async_function_wrong_return_type_remote(self):
1377        rref = rpc.remote(
1378            worker_name((self.rank + 1) % self.world_size), async_wrong_type
1379        )
1380
1381        with self.assertRaisesRegex(
1382            RuntimeError,
1383            "Async functions must return an IValue of Future type, but got Tensor",
1384        ):
1385            rref.to_here()
1386