# mypy: allow-untyped-defs

import concurrent.futures
import contextlib
import json
import os
import sys
import threading
import time

from collections import namedtuple
from functools import partial
from threading import Event
from threading import Lock
from unittest import mock

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd
from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info, WorkerInfo
from torch.distributed.rpc.api import _use_rpc_pickler, _thread_local_var, _wait_all
from torch.distributed.rpc.internal import (
    PythonUDF,
    RPCExecMode,
    _internal_rpc_pickler,
    _build_rpc_profiling_key,
)
from torch.futures import Future
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
    captured_output,
    tp_transports,
)
from torch.testing._internal.common_utils import (
    IS_MACOS,
    load_tests,
    skip_but_pass_in_sandcastle_if,
    get_cycles_per_ms,
)

from torch.testing._internal.dist_utils import (
    dist_init,
    get_function_event,
    initialize_pg,
    wait_until_node_failure,
    wait_until_pending_futures_and_users_flushed,
    wait_until_owners_and_forks_on_rank,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.common_utils import TemporaryFileName

from torch.autograd.profiler_legacy import profile as _profile
import operator


def foo_add():
    return torch.add(torch.ones(1), torch.ones(1))

def udf_with_torch_ops(device=-1, use_record_function=False):
    device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device)
    record_function_ctx = (
        torch.autograd.profiler.record_function("##forward##")
        if use_record_function
        else contextlib.nullcontext()
    )
    with device_ctx, record_function_ctx:
        t1, t2 = torch.ones(1), torch.ones(1)
        t = torch.add(t1, t2)
        t = torch.mul(t, t)
        t = t.relu()
        t = t.sigmoid()

# Events (operator invocations) that are expected to run as part of the above
# function.
EXPECTED_REMOTE_EVENTS = [
    "aten::ones",
    "aten::ones",
    "aten::add",
    "aten::mul",
    "aten::relu",
    "aten::clamp_min",
    "aten::sigmoid",
]

# Remote operations are prefixed with the following string for RPC profiling.
REMOTE_OP_STR = "#remote_op: "
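

# Illustrative sketch (not used by the tests below): remote profiling keys are
# assembled from the execution mode, the function name, and the names of the
# caller and callee workers. The exact key format is an implementation detail
# of rpc/internal.py; this only shows how the helper imported above is invoked.
def _example_build_profiling_key():
    return _build_rpc_profiling_key(
        RPCExecMode.SYNC, "torch.add", "worker0", "worker1"
    )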


VALUE_FUTURE = concurrent.futures.Future()
DONE_FUTURE = concurrent.futures.Future()

FIFTY_MIL_CYCLES = 50000000

_rpc_barrier_count = 0

def _increment_count():
    global _rpc_barrier_count
    _rpc_barrier_count += 1

def _reset_count():
    global _rpc_barrier_count
    _rpc_barrier_count = 0

class StubRpcAgent:
    def __init__(self, world_size):
        self.world_size = world_size

    def get_worker_infos(self):
        return {
            WorkerInfo(name=worker_name(rank), id=rank)
            for rank in range(self.world_size)
        }


def _stub_construct_rpc_backend_options_handler(**kwargs):
    return mock.Mock()  # RpcBackendOptions.


def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options):
    return StubRpcAgent(world_size=world_size)


def set_value(value):
    VALUE_FUTURE.set_result(value)


def wait_for_value_future():
    return VALUE_FUTURE.result()


def set_and_check_done(value):
    VALUE_FUTURE.set_result(value)
    return DONE_FUTURE.result()


# The classes and functions below are used to test Python user-defined
# functions, classes, and methods over RPC.
TensorClass = namedtuple("TensorClass", ["tensors"])

class MyPickleClass:
    def __init__(self) -> None:
        self.t = None

    def __getstate__(self):
        (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
            PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
        )
        return (pickled_python_udf, tensors)

    def __setstate__(self, obj):
        python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
        result = python_udf.func(python_udf.args[0], python_udf.args[1])
        self.t = result

    def set(self, val):
        self.t = val


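# Illustrative sketch (not part of the test suite): the internal RPC pickler
# round-trips a PythonUDF into a (payload, tensor_table) pair and back, which
# is exactly what MyPickleClass.__getstate__/__setstate__ exercise above.
def _example_internal_pickler_roundtrip():
    payload, tensors = _internal_rpc_pickler.serialize(
        PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
    )
    udf = _internal_rpc_pickler.deserialize(payload, tensors)
    return udf.func(*udf.args)

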
class SlowPickleClass:
    def __init__(self, t):
        self.t = t

    def __getstate__(self):
        time.sleep(self.t)
        return (self.t, )

    def __setstate__(self, obj):
        self.t = obj[0]
        time.sleep(self.t)


class MyClass:
    def __init__(self, a, delay=False):
        self.a = a
        # delay initialization to simulate errors if specified
        if delay:
            time.sleep(2)

    def my_instance_method(self, b):
        return self.a + b

    @classmethod
    def my_class_method(cls, d, e):
        return d + e

    @staticmethod
    def my_static_method(f):
        return f > 10

    def increment_value(self, increment):
        self.a += increment

    def get_value(self):
        return self.a

    def my_slow_method(self, my_tensor_arg):
        time.sleep(5)
        return torch.add(self.a, my_tensor_arg)


def _call_method_on_rref(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)


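# Illustrative sketch (assumes an initialized RPC agent and a peer named
# "worker1"): _call_method_on_rref lets a caller run an instance method on the
# worker that owns an RRef.
def _example_call_method_on_rref():
    rref = rpc.remote("worker1", MyClass, args=(10,))
    rpc.rpc_sync(
        rref.owner(), _call_method_on_rref, args=(MyClass.increment_value, rref, 5)
    )
    return rpc.rpc_sync(
        rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
    )

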
def get_rref_list(values):
    return [RRef(MyClass(a)) for a in values]


def add_rref_to_value(rref, value):
    return rref.to_here() + value


def run_nested_pickle(pickle_cls_instance, tensor):
    return pickle_cls_instance.t + tensor

def build_sparse_tensor(coalesce=False):
    i = [[0, 1, 1], [2, 0, 2]]
    v = [3, 4, 5]
    tensor = torch.sparse_coo_tensor(i, v, (2, 3))
    if coalesce:
        tensor = tensor.coalesce()
    return tensor

def build_complex_tensors():
    a = torch.ones(3, 3)
    b = [a, a]
    c = [b, b]
    d = [a, b]
    e = {a: d}
    return [a, b, c, d, e]

def non_cont_test(t_view, t_cont):
    if t_view.is_contiguous():
        raise Exception('t_view is contiguous!')  # noqa: TRY002
    if not t_cont.is_contiguous():
        raise Exception('t_cont is not contiguous!')  # noqa: TRY002
    if not torch.equal(t_view, t_cont):
        raise Exception('t_view is not equal to t_cont!')  # noqa: TRY002
    return t_view

def my_function(a, b, c):
    return a + b + c


def my_tensor_function(a, b):
    return a + b

def my_container_sum(a):
    result = a[0]
    for tensor in a[1:]:
        result += tensor
    return result


def my_sleep_func(seconds=1):
    time.sleep(seconds)
    return torch.mul(torch.tensor(1), torch.tensor(1))


def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
    res = list_input[0]
    for t in list_input:
        res += t
    for v in dict_input.values():
        res += v
    complex_tensors = tensor_class_input.tensors
    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])


def my_rref_function(rref_a, rref_b):
    return rref_a.to_here() + rref_b.to_here()


def delayed_add(a, b, seconds=0.05):
    time.sleep(seconds)
    return a + b


def identity(a):
    return a

def no_result():
    print("do nothing")

def raise_or_inc(value):
    if value.numel() == 2:
        raise ValueError("Expected error")
    return value + 1

def nested_rpc(dst):
    return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))


def nested_rpc_sparse(dst):
    return rpc.rpc_sync(
        dst,
        torch.add,
        args=(build_sparse_tensor(), build_sparse_tensor())
    )


def multi_layer_nested_async_rpc(dst, world_size, ttl):
    # this method returns immediately without blocking the callee, but will
    # generate additional requests.
    if ttl > 0:
        current_dst = worker_name(dst)
        next_dst = (dst + 1) % world_size
        rpc.rpc_async(
            current_dst,
            multi_layer_nested_async_rpc,
            args=(next_dst, world_size, ttl - 1),
        )
        return 0


def nested_rref(dst):
    return (
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
    )


def nested_rref_sparse(dst):
    return (
        rpc.remote(
            dst,
            torch.add,
            args=(build_sparse_tensor(), build_sparse_tensor())
        ),
        rpc.remote(
            dst,
            torch.add,
            args=(build_sparse_tensor(), build_sparse_tensor())
        ),
    )


def nested_remote(dst):
    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
    return rref.to_here()

def nested_remote_sparse(dst):
    rref = rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor()))
    return rref.to_here()


def rref_forward_chain(dst, world_size, rref, ttl):
    if ttl > 0:
        current_dst = worker_name(dst)
        next_dst = (dst + 1) % world_size
        ret_rref = rpc.remote(
            current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1)
        )
        return [ret_rref]
    else:
        return rref.to_here()


def rpc_return_rref(dst):
    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))


def light_rpc():
    return 0


def heavy_rpc(tensor):
    for i in range(1, 100):
        tensor *= i
        tensor /= i + 1
    return 0


def heavy_rpc_sparse(tensor):
    for i in range(1, 100):
        tensor *= i
        tensor = tensor / (i + 1)
    return 0

@torch.jit.script
def heavy_rpc_torchscript(tensor):
    for i in range(1, 100):
        tensor *= i
        tensor /= i + 1
    return 0


@torch.jit.script
def my_script_func(tensor):
    return torch.add(tensor, tensor)


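# Illustrative sketch (assumes an initialized agent): TorchScript functions
# such as my_script_func can be passed to the RPC APIs exactly like plain
# Python functions.
def _example_script_rpc(dst, tensor):
    return rpc.rpc_sync(dst, my_script_func, args=(tensor,))

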
expected_err = "Expected error"

# Note that it needs to inherit from Exception, not BaseException. See comment
# in rpc/internal.py
class CustomException(Exception):
    def __init__(self, bool, msg):
        self.bool = bool
        super().__init__(msg)

def raise_func():
    raise ValueError(expected_err)

def custom_raise_func():
    raise CustomException(True, "foo")

@torch.jit.script
def raise_func_script(expected_err: str) -> torch.Tensor:
    raise ValueError(expected_err)

expected_err_escape = "\nFirst line of error \n next line of error \n last line of error"
def raise_func_escape():
    raise ValueError(expected_err_escape)


global_rref = None


def set_global_rref(rref):
    global global_rref
    global_rref = rref


def clear_global_rref():
    global global_rref
    global_rref = None


def check_rref_confirmed(rref):
    return rref.confirmed_by_owner()


def get_rref_debug_info():
    return _rref_context_get_debug_info()


def add_use_future_cb(to, x, y, z):
    out = concurrent.futures.Future()

    def callback(fut):
        out.set_result(fut.wait() + z)

    fut = rpc.rpc_async(to, torch.add, args=(x, y))
    fut.then(callback)
    return out.result()


def get_events_from_profile(profile_rref):
    return profile_rref.local_value().process_global_function_events


def add_use_future_set_result(to, x, y, z):
    out = torch.futures.Future()
    fut = rpc.rpc_async(to, torch.add, args=(x, y))
    fut.then(lambda fut: out.set_result(fut.wait() + z))
    return out.wait()


def add_use_future_nested_cb(to, x, y, z):
    out = torch.futures.Future()

    def callback(fut1):
        fut2 = rpc.rpc_async(to, torch.add, args=(fut1.wait(), z))
        fut2.then(lambda fut2: out.set_result(fut2.wait()))

    fut1 = rpc.rpc_async(to, torch.add, args=(x, y))
    fut1.then(callback)
    return out.wait()


def fail_on_fut(fut):
    pass


@rpc.functions.async_execution
def async_raise_func():
    raise RuntimeError("Expected error")


@rpc.functions.async_execution
def async_wrong_type():
    return torch.zeros(2, 2)


@rpc.functions.async_execution
def async_add(to, x, y):
    return rpc.rpc_async(to, torch.add, args=(x, y))


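# Illustrative sketch (assumes an initialized agent and peers named "worker0"
# and "worker1"): because async_add is decorated with
# @rpc.functions.async_execution, the callee treats the returned Future as the
# eventual response, so the caller receives the resolved tensor, not a Future.
def _example_async_add_caller(x, y):
    return rpc.rpc_sync("worker0", async_add, args=("worker1", x, y))

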
def slow_add(x, y, device="cpu"):
    time.sleep(1)
    x = x.to(device)
    y = y.to(device)
    return torch.add(x, y).cpu()


@rpc.functions.async_execution
def slow_async_add(to, x, y, device="cpu"):
    return rpc.rpc_async(to, slow_add, args=(x, y, device))


@rpc.functions.async_execution
def async_add_with_future_ctor(to, x, y, z):
    fut = torch.futures.Future()
    rpc.rpc_async(to, torch.add, args=(x, y)).then(
        lambda fut1: fut.set_result(fut1.wait() + z)
    )
    return fut


@rpc.functions.async_execution
def async_add_chained(to, x, y, z):
    return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        lambda fut: fut.wait() + z
    )


@rpc.functions.async_execution
def async_add_chained_multi(to, x, num, step):
    fut = rpc.rpc_async(to, torch.add, args=(x, 0))
    for _ in range(num):
        fut = fut.then(lambda fut: fut.wait() + step)
    return fut


@rpc.functions.async_execution
def async_add_nested(to, x, y, z):
    return rpc.rpc_async(to, async_add, args=(to, x, y)).then(
        lambda fut: fut.wait() + z
    )


@rpc.functions.async_execution
def async_add_multi_fanout(to, x, num, step):
    futs = []
    for i in range(num):
        if i == 0:
            futs.append(rpc.rpc_async(to, torch.add, args=(x, step)))
        else:
            futs.append(rpc.rpc_async(to, torch.add, args=(0, step)))

    # TODO: use torch.futures.collect_all (see the sketch after this function)
    lock = Lock()
    state = {"cnt": 0, "ret": torch.zeros_like(x)}
    ret_future = torch.futures.Future()

    def inc_and_set(fut):
        with lock:
            state["cnt"] += 1
            state["ret"] += fut.wait()
            if state["cnt"] >= len(futs):
                ret_future.set_result(state["ret"])

    for fut in futs:
        fut.then(inc_and_set)

    return ret_future


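# Illustrative sketch for the TODO above (not used by the tests):
# torch.futures.collect_all returns a Future that completes once every input
# Future has, which replaces the manual lock-and-counter bookkeeping.
@rpc.functions.async_execution
def _example_async_add_multi_fanout_collect(to, x, num, step):
    futs = [
        rpc.rpc_async(to, torch.add, args=(x if i == 0 else 0, step))
        for i in range(num)
    ]
    return torch.futures.collect_all(futs).then(
        lambda all_futs: sum(f.wait() for f in all_futs.wait())
    )

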
@rpc.functions.async_execution
def async_cuda_sleep_and_set_to_one(t):
    device = t.device
    original_stream = torch.cuda.current_stream(device)
    new_stream = torch.cuda.Stream(device)
    new_stream.wait_stream(original_stream)
    with torch.cuda.stream(new_stream):
        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
        t.fill_(1)
        fut = Future(devices=[device])
        fut.set_result(t)
        return fut


@rpc.functions.async_execution
def async_cuda_nested_add(to, x, y, z):
    def cb(fut):
        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
        return fut.value() + z

    return rpc.rpc_async(to, torch.add, args=(x, y)).then(cb)


# A custom Python class that contains a tensor, needed to see if we correctly
# use the Python pickler to extract tensors from non-IValue-convertible types.
class TensorWrapper:
    __slots__ = ("tensor", "lock", "event", "thread")

    def __init__(self, t):
        self.tensor = t
        # Add one non-picklable field, to ensure it's ignored/skipped.
        self.lock = Lock()
        self.event = torch.cuda.Event(enable_timing=True)
        self.thread = threading.Thread()
        self.thread.start()

    def increase(self, v):
        with self.lock:
            self.tensor += v

    def sum(self):
        with self.lock:
            self.event.record()
            return self.tensor.sum()


class AsyncExecutionClass:

    @staticmethod
    @rpc.functions.async_execution
    def static_async_add(to, x, y, z):
        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: fut.wait() + z
        )

    @classmethod
    @rpc.functions.async_execution
    def class_async_add(cls, to, x, y, z):
        ret_fut = torch.futures.Future()
        rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: ret_fut.set_result(fut.wait() + z)
        )
        return ret_fut

    @rpc.functions.async_execution
    def bound_async_add(self, to, x, y, z):
        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: fut.wait() + z
        )


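# Illustrative sketch (assumes an initialized agent and peers "worker0" and
# "worker1"): static methods decorated with @rpc.functions.async_execution can
# be passed to RPC like free functions; the caller gets the resolved value.
def _example_static_async_add(x, y, z):
    return rpc.rpc_sync(
        "worker0", AsyncExecutionClass.static_async_add, args=("worker1", x, y, z)
    )

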
def return_future():
    return torch.futures.Future()


class FooBackendOptions(rpc.RpcBackendOptions):
    def __init__(self, init_method):
        # Must call the __init__ of the superclass (and do so directly,
        # without using super()) because... pybind.
        rpc.RpcBackendOptions.__init__(self)
        self.init_method = init_method


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake8 warnings.
load_tests = load_tests


class MyEmbeddingBagModel(torch.nn.Module):
    def __init__(self, sparse):
        super().__init__()
        self.eb = torch.nn.EmbeddingBag(
            10,
            10,
            sparse=sparse
        )

    def forward(self, x):
        return self.eb(x)


class MyParameterServer:
    def __init__(self, trainers):
        self.lock = Lock()
        self.trainers = trainers
        self.iteration = 0
        self.updates = 0
        self.futures = []
        self.total = None
        self.gradient = None

    @staticmethod
    def get_gradient(rref):
        return rref.local_value().gradient

    @staticmethod
    @rpc.functions.async_execution
    def average(rref, riteration, tensor):
        self = rref.local_value()
        fut = torch.futures.Future()
        with self.lock:
            if riteration > self.iteration:
                self.iteration = riteration
                self.updates = 0
                self.futures.clear()
            self.futures.append(fut)
            if self.total is None:
                self.total = tensor
            else:
                self.total += tensor
            self.updates += 1
            if self.trainers == self.updates:
                self.gradient = self.total / float(self.trainers)
                for fut in self.futures:
                    result = self.total / float(self.trainers)
                    fut.set_result(result)
        return fut

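
# Illustrative sketch (assumes ps_rref owns a MyParameterServer and the agent
# is initialized): average() is async_execution, so each trainer's call only
# resolves once all trainers have contributed that iteration's gradient.
def _example_average_gradient(ps_rref, iteration, grad):
    return ps_rref.rpc_async().average(ps_rref, iteration, grad).wait()

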
705
706class MyConvNetForMNIST(nn.Module):
707    def __init__(self, device):
708        super().__init__()
709        self.net = nn.Sequential(
710            nn.Conv2d(1, 16, 3, 1),
711            nn.ReLU(),
712            nn.Conv2d(16, 32, 3, 1),
713            nn.ReLU(),
714            nn.MaxPool2d(2),
715            nn.Flatten(1),
716            nn.Linear(4608, 128),
717            nn.ReLU(),
718            nn.Linear(128, 10),
719        ).to(device)
720        self.device = device
721
722    def forward(self, x, is_rref=False):
723        x = x.to_here() if is_rref else x
724        with torch.cuda.stream(torch.cuda.current_stream(self.device)):
725            # intentionally adding delay to current CUDA stream
726            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
727            return self.net(x)
728
729    def __getstate__(self):
730        # return an empty dict to avoid inspecting the model contents on the
731        # owner
732        return {}
733
734
735class RpcTestCommon:
736    def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None):
737        if mode == RPCExecMode.SYNC:
738            return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs)
739        elif mode == RPCExecMode.ASYNC:
740            return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait()
741        elif mode == RPCExecMode.REMOTE:
742            return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here()
743
744    def _self_py_udf_remote(self, worker_info, x, y, z):
745        rref = rpc.remote(worker_info, my_function, args=(x, y, z))
746        self.assertEqual(rref.to_here(), x + y + z)
747
748    def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
749        self_worker_info = rpc.get_worker_info()
750        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
751        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x))
752        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y))
753        self.assertEqual(ret, x + y + z + x + y)
754        self.assertEqual(fut.wait(), x + y + z + x)
755
756    def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
757        self_worker_info = rpc.get_worker_info()
758        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
759        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x))
760        self.assertEqual(
761            ret_rref.to_here(), x + y + z + x
762        )
763
764    def _world_size_one(self, a, b):
765        if self.rank == 0:
766            rpc.init_rpc(
767                name="me",
768                backend=self.rpc_backend,
769                rank=0,
770                world_size=1,
771                rpc_backend_options=self.rpc_backend_options,
772            )
773
774            def _rpc_sync(x, y):
775                expect = x * 2
776                result = rpc.rpc_sync(
777                    "me",
778                    my_tensor_function,
779                    args=(x, y)
780                )
781                self.assertEqual(expect, result)
782
783            def _rpc_async(x, y):
784                expect = x * 2
785                result = rpc.rpc_async(
786                    "me",
787                    my_tensor_function,
788                    args=(x, y)
789                ).wait()
790                self.assertEqual(expect, result)
791
792            def _remote(x, y):
793                expect = x * 2
794                result = rpc.remote(
795                    "me",
796                    my_tensor_function,
797                    args=(x, y)
798                ).to_here()
799                self.assertEqual(expect, result)
800
801            _rpc_sync(a, b)
802            _rpc_async(a, b)
803            _remote(a, b)
804
805            rpc.shutdown()
806
807    def _multi_rpc(self, sparse):
808        dst_rank = (self.rank + 1) % self.world_size
809        for i in range(20):
810            n = i + self.rank + 1
811            if sparse:
812                x = build_sparse_tensor() * n
813                y = build_sparse_tensor() * n
814            else:
815                x = torch.ones(2, 2)
816                y = torch.ones(2, 2)
817            ret = rpc.rpc_sync(
818                worker_name(dst_rank),
819                torch.add,
820                args=(x, y),
821            )
822            self.assertEqual(ret, x * 2)
823
824    def _run_uneven_workload(self, f, x, num_repeat=30):
825        # worker0 drives and waits for worker1 and worker2
826        # throughout the test.
827        if self.rank == 0:
828            self.assertTrue(self.world_size >= 3)
829
830            # Phase 1: Only worker1 has workload.
831            dst = "worker1"
832            futs = []
833            for _ in range(num_repeat):
834                fut = rpc.rpc_async(dst, f, args=(x,))
835                futs.append(fut)
836
837            for fut in torch.futures.collect_all(futs).wait():
838                self.assertEqual(fut.wait(), 0)
839
840            # Phase 2: Only worker2 has workload.
841            # If join is not correctly implemented,
842            # worker2 should be closed by now.
843            dst = "worker2"
844            futs = []
845            for _ in range(num_repeat):
846                fut = rpc.rpc_async(dst, f, args=(x,))
847                futs.append(fut)
848
849            for val in torch.futures.wait_all(futs):
850                self.assertEqual(val, 0)
851
852    def _wait_all_workers(self, f, x):
853        initialize_pg(self.file_init_method, self.rank, self.world_size)
854        rpc.init_rpc(
855            name="worker%d" % self.rank,
856            backend=self.rpc_backend,
857            rank=self.rank,
858            world_size=self.world_size,
859            rpc_backend_options=self.rpc_backend_options,
860        )
861
862        self._run_uneven_workload(f, x)
863
        # worker0 calls this at the end after waiting for RPC responses.
        # worker1/2 call this immediately and have some work after it.
        # worker3 calls this immediately and has no more work.
        rpc.api._wait_all_workers()

        # Wait before proceeding to shutdown to ensure worker0 RPCs make
        # it through to other workers.
        dist.barrier()
        rpc.shutdown(graceful=False)

    def _wait_all_workers_twice(self, f, x):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        self._run_uneven_workload(f, x)

        # worker0 calls this at the end after waiting for RPC responses.
        # worker1/2 call this immediately and have some work after it.
        # worker3 calls this immediately and has no more work.
        rpc.api._wait_all_workers()
        rpc.api._wait_all_workers()

        # Wait before proceeding to shutdown to ensure worker0 RPCs make
        # it through to other workers.
        dist.barrier()
        rpc.shutdown(graceful=False)

    def _nested_rpc(self, f, expected):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            f,
            args=(worker_name(self.rank),),
        )
        self.assertEqual(ret, expected)

    def _stress_test_rpc(self, f, repeat=1000, args=()):
        n = self.rank + 1
        dst_rank = n % self.world_size
        futs = []
        tik = time.time()
        for _ in range(repeat):
            fut = rpc.rpc_async(worker_name(dst_rank), f, args=args)
            futs.append(fut)

        for val in torch.futures.wait_all(futs):
            self.assertEqual(val, 0)
        tok = time.time()
        print(
            f"Rank {self.rank} finished testing {repeat} times in {tok - tik} seconds."
        )

    def _builtin_remote_ret(self, x, y, expected):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = rpc.remote(
            worker_name(dst_rank),
            torch.add,
            args=(x, y),
        )
        self.assertEqual(rref.to_here(), expected)

    def _builtin_remote_self(self, x, y, expected):
        rref = rpc.remote(
            worker_name(self.rank),
            torch.add,
            args=(x, y),
        )
        self.assertEqual(rref.local_value(), expected)

    def _test_multi_remote_call(self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}):
        m = 10
        n = self.rank + 1
        dst_rank = n % self.world_size
        rrefs = []
        expected = []
        for i in range(m):
            n = n + i
            rrefs.append(
                rpc.remote(
                    worker_name(dst_rank),
                    fn,
                    args=args_fn(n, sparse),
                    kwargs=kwargs_fn(n, sparse),
                )
            )
            expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse)))

        for i in range(m):
            self.assertEqual(rrefs[i].to_here(), expected[i])

    def _py_rref_args(self, a, b, x, y, expected):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = rpc.remote(
            worker_name(dst_rank), torch.add, args=(a, b)
        )
        rref_b = rpc.remote(
            worker_name(dst_rank), torch.add, args=(x, y)
        )
        rref_c = rpc.remote(
            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), expected)

    def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
        n = self.rank + 1
        owner_rank = n % self.world_size
        user_rank = (n + 1) % self.world_size
        rref_a = rpc.remote(
            worker_name(owner_rank), my_function, args=(a, b, c)
        )
        rref_b = rpc.remote(
            worker_name(owner_rank), my_function, args=(x, y, z)
        )
        rref_c = rpc.remote(
            worker_name(user_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), expected)

    def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = rpc.remote(
            worker_name(dst_rank), my_function, args=(a, b, c)
        )
        rref_b = rpc.remote(
            worker_name(dst_rank), my_function, args=(x, y, z)
        )

        c = rpc.rpc_sync(
            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(c, expected)

    def _nested_remote(self, f, expected):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size

        rref = rpc.remote(
            worker_name(dst_rank1),
            f,
            args=(worker_name(dst_rank2),),
        )
        self.assertEqual(rref.to_here(), expected)

    def _nested_rref(self, f, expected1, expected2):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref_of_rrefs = rpc.remote(
            worker_name(dst_rank1),
            f,
            args=(worker_name(dst_rank2),),
        )

        # Say C has 2 OwnerRRefs.
        # B has 2 UserRRefs to those 2 OwnerRRefs, respectively.
        # This call is effectively A asking B to share its 2 UserRRefs.
        rrefs = rref_of_rrefs.to_here()

        self.assertEqual(len(rrefs), 2)
        self.assertEqual(rrefs[0].to_here(), expected1)
        self.assertEqual(rrefs[1].to_here(), expected2)

    def _nested_rref_stress(self, f, expected1, expected2):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        all_rrefs = []
        for _ in range(20):
            all_rrefs.append(
                rpc.remote(
                    worker_name(dst_rank1),
                    f,
                    args=(worker_name(dst_rank2),),
                )
            )

        for i in range(20):
            rref_of_rrefs = all_rrefs[i]
            rrefs = rref_of_rrefs.to_here()
            self.assertEqual(len(rrefs), 2)
            self.assertEqual(rrefs[0].to_here(), expected1)
            self.assertEqual(rrefs[1].to_here(), expected2)

    def _trainer_func(self, rref, sparse):
        m = MyEmbeddingBagModel(sparse=sparse)
        loss_fn = nn.MSELoss()
        for i in range(10):
            outputs = m(torch.rand(10, 10).long())
            loss_fn(outputs, torch.rand(10, 10)).backward()
            gradient = next(iter(m.parameters())).grad
            fut = rref.rpc_async().average(rref, i, gradient)
            gradient = fut.wait()
            if gradient.is_sparse:
                gradient = gradient.to_dense().double()
            ps_gradient = rref.rpc_sync().get_gradient(rref)
            if ps_gradient.is_sparse:
                ps_gradient = ps_gradient.to_dense().double()
            self.assertTrue(torch.equal(gradient, ps_gradient))

    def _my_parameter_server(self, sparse):
        ps_rref = RRef(MyParameterServer(self.world_size - 1))
        futures = []
        for index in range(1, self.world_size):
            futures.append(
                rpc.rpc_async(
                    worker_name((self.rank + index) % self.world_size),
                    self._trainer_func,
                    args=(
                        ps_rref,
                        sparse
                    ),
                )
            )
        torch.futures.wait_all(futures)

    def _test_cuda_future_extraction(self, wrapper, unwrapper, sparse_tensor):
        # We check proper CUDA stream synchronization by adding to the tensor
        # in one stream to get the expected value, and reading it from another stream.
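        # (Assumption about torch.futures.Future CUDA semantics, noted for
        # context: with devices=["cuda:0"], set_result() records CUDA events
        # on the producing stream and wait() blocks the consuming stream on
        # those events.)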
        future = Future(devices=["cuda:0"])
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                if sparse_tensor:
                    tensor = build_sparse_tensor().to("cuda:0")
                    add_tensor = build_sparse_tensor().to("cuda:0")
                    expected_tensor = (tensor + add_tensor).coalesce()
                else:
                    tensor = torch.zeros((100,), device="cuda:0")
                    add_tensor = torch.ones((100,), device="cuda:0")
                    expected_tensor = tensor + add_tensor
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor += add_tensor
                if sparse_tensor:
                    tensor = tensor.coalesce()
                future.set_result(wrapper(tensor))
            with torch.cuda.stream(another_stream):
                tensor = unwrapper(future.wait())
                if sparse_tensor:
                    self.assertTrue(torch.eq(tensor.indices(), expected_tensor.indices()).all().item())
                    self.assertTrue(torch.eq(tensor.values(), expected_tensor.values()).all().item())
                    self.assertEqual(tensor.size(), expected_tensor.size())
                else:
                    self.assertTrue(torch.eq(tensor, expected_tensor).all().item())


class RpcTest(RpcAgentTestFixture, RpcTestCommon):
    @dist_init
    def test_worker_id(self):
        n = self.rank + 1
        peer_rank = n % self.world_size
        self_worker_info = rpc.get_worker_info()
        peer_worker_info = rpc.get_worker_info(worker_name(peer_rank))

        self.assertEqual(self_worker_info.name, worker_name(self.rank))
        self.assertEqual(peer_worker_info.name, worker_name(peer_rank))

        with self.assertRaisesRegex(RuntimeError, "could not find destination"):
            unknown_worker_id = rpc.get_worker_info("WorkerUnknown")

    @dist_init
    def test_get_worker_infos(self):
        worker_infos = rpc.api._get_current_rpc_agent().get_worker_infos()

        worker_names = {worker_info.name for worker_info in worker_infos}
        expected_worker_names = {
            worker_name(rank) for rank in range(self.world_size)
        }
        self.assertEqual(worker_names, expected_worker_names)

        worker_ids = {worker_info.id for worker_info in worker_infos}
        expected_worker_ids = set(range(self.world_size))
        self.assertEqual(worker_ids, expected_worker_ids)

    @dist_init
    def test_self_add(self):
        self_worker_info = rpc.get_worker_info()
        self_worker_name = worker_name(self.rank)
        fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
        ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    @dist_init
    def test_send_to_rank(self):
        dst_rank = (self.rank + 1) % self.world_size

        # Test dense tensor
        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1))
            self.assertEqual(ret, torch.ones(2, 2) + 1)

        # Test invalid ranks
        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            with self.assertRaises(RuntimeError):
                self._run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1))

        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            with self.assertRaises(RuntimeError):
                self._run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1))

        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            with self.assertRaises(ValueError):
                self._run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1))

        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            with self.assertRaises(ValueError):
                self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1))

    @dist_init
    def test_self_py_udf_remote(self):
        self._self_py_udf_remote(
            rpc.get_worker_info(),
            torch.ones(2, 2),
            1,
            3
        )

    @dist_init
    def test_self_remote_rref_as_rpc_arg(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        self._self_remote_rref_as_rpc_arg(
            dst,
            torch.ones(2, 2),
            1,
            3
        )

    @dist_init
    def test_self_remote_rref_as_self_rpc_arg(self):
        self._self_remote_rref_as_rpc_arg(
            rpc.get_worker_info(),
            torch.ones(2, 2),
            1,
            3
        )

    @dist_init
    def test_self_remote_rref_as_remote_arg(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        self._self_remote_rref_as_remote_arg(
            dst,
            torch.ones(2, 2),
            1,
            3
        )

    @dist_init
    def test_self_remote_rref_as_self_remote_arg(self):
        self._self_remote_rref_as_remote_arg(
            rpc.get_worker_info(),
            torch.ones(2, 2),
            1,
            3
        )

    @dist_init
    def test_rref_proxy_non_exist(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
        msg = "has no attribute 'non_exist'"
        with self.assertRaisesRegex(AttributeError, msg):
            rref.rpc_sync().non_exist()

        with self.assertRaisesRegex(AttributeError, msg):
            rref.rpc_async().non_exist().wait()

        with self.assertRaisesRegex(AttributeError, msg):
            rref.remote().non_exist()

    def _test_rref_proxy_tensor(self, dst):
        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))

        expected = torch.ones(2, 2) + 1 + 3
        self.assertEqual(expected.size(), rref.rpc_sync().size())
        self.assertEqual(expected + 1, rref.rpc_async().add(1).wait())
        self.assertEqual(expected.view(1, 4), rref.remote().view(1, 4).to_here())

    @dist_init
    def test_rref_proxy_tensor(self):
        self._test_rref_proxy_tensor(worker_name((self.rank + 1) % self.world_size))

    @dist_init
    def test_rref_proxy_tensor_self(self):
        self._test_rref_proxy_tensor(rpc.get_worker_info())

    @dist_init
    def test_rref_proxy_reuse(self):
        rref = rpc.remote(
            worker_name((self.rank + 1) % self.world_size),
            my_function,
            args=(torch.ones(2, 2), 1, 3)
        )
        expected = torch.ones(2, 2) + 1 + 3

        proxy_rpc_sync = rref.rpc_sync()
        proxy_rpc_async = rref.rpc_async()
        proxy_remote = rref.remote()

        self.assertEqual(expected.size(), proxy_rpc_sync.size())
        self.assertEqual(expected + 1, proxy_rpc_sync.add(1))
        self.assertEqual(expected.view(1, 4), proxy_rpc_sync.view(1, 4))

        self.assertEqual(expected.size(), proxy_rpc_async.size().wait())
        self.assertEqual(expected + 3, proxy_rpc_async.add(3).wait())
        self.assertEqual(expected.view(4, 1), proxy_rpc_async.view(4, 1).wait())

        self.assertEqual(expected.size(), proxy_remote.size().to_here())
        self.assertEqual(expected + 5, proxy_remote.add(5).to_here())
        self.assertEqual(expected.view(-1), proxy_remote.view(-1).to_here())

    def _test_rref_proxy_class(self, dst):
        rref = rpc.remote(dst, MyClass, args=(7,))
        expected = MyClass(7)
        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())

        expected.increment_value(3)
        self.assertEqual(None, rref.rpc_sync().increment_value(1))
        self.assertEqual(None, rref.rpc_async().increment_value(1).wait())
        self.assertEqual(None, rref.remote().increment_value(1).to_here())

        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())

        self.assertEqual(
            expected.my_instance_method(2),
            rref.rpc_sync().my_instance_method(2)
        )
        self.assertEqual(
            expected.my_instance_method(3),
            rref.rpc_async().my_instance_method(3).wait()
        )
        self.assertEqual(
            expected.my_instance_method(4),
            rref.remote().my_instance_method(4).to_here()
        )

        self.assertEqual(
            expected.my_static_method(9),
            rref.rpc_sync().my_static_method(9)
        )
        self.assertEqual(
            expected.my_static_method(10),
            rref.rpc_async().my_static_method(10).wait()
        )
        self.assertEqual(
            expected.my_static_method(11),
            rref.remote().my_static_method(11).to_here()
        )

        self.assertEqual(
            expected.my_class_method(2, torch.zeros(2, 2)),
            rref.rpc_sync().my_class_method(2, torch.zeros(2, 2))
        )
        self.assertEqual(
            expected.my_class_method(2, torch.ones(3, 3)),
            rref.rpc_async().my_class_method(2, torch.ones(3, 3)).wait()
        )
        self.assertEqual(
            expected.my_class_method(2, torch.ones(4, 4)),
            rref.remote().my_class_method(2, torch.ones(4, 4)).to_here()
        )

    @dist_init
    def test_rref_proxy_class(self):
        self._test_rref_proxy_class(worker_name((self.rank + 1) % self.world_size))

    @dist_init
    def test_rref_proxy_class_self(self):
        self._test_rref_proxy_class(rpc.get_worker_info())

    @mock.patch.object(torch.distributed.autograd, "_init")
    @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent")
    @dist_init(setup_rpc=False)
    def test_register_rpc_backend_and_set_and_start_rpc_backend(
        self, mock_rpc_agent, mock_dist_autograd_init
    ):
        backend_name = "stub_backend"

        backend = rpc.backend_registry.register_backend(
            backend_name,
            _stub_construct_rpc_backend_options_handler,
            _stub_init_rpc_backend_handler,
        )

        with self.assertRaisesRegex(
            RuntimeError, "^RPC backend .+: already registered$"
        ):
            backend = rpc.backend_registry.register_backend(
                backend_name,
                _stub_construct_rpc_backend_options_handler,
                _stub_init_rpc_backend_handler,
            )

        rpc.init_rpc(
            name="worker1",
            backend=backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    @dist_init(setup_rpc=False)
    def test_duplicate_name(self):
        with self.assertRaisesRegex(RuntimeError, "is not unique"):
            store, _, _ = next(
                torch.distributed.rendezvous(
                    self.init_method, rank=self.rank, world_size=self.world_size
                )
            )
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="duplicate_name",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

    @dist_init(setup_rpc=False)
    def test_duplicate_name_2(self):
        with self.assertRaisesRegex(RuntimeError, "is not unique"):
            rpc.init_rpc(
                name=worker_name(self.rank % (self.world_size - 1)),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

    @dist_init(setup_rpc=False)
    def test_reinit(self):
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        initialize_pg(self.file_init_method, self.rank, self.world_size)
        # Wait for all init to complete.
        dist.barrier()

        # TODO: with TCP init, rank 0 raises Address already in use because
        # rank 0 is the start daemon and the store is created before checking if
        # RPC is already initialized in init_rpc.
        if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0:
            expected_reinit_err = "Address already in use"
        else:
            expected_reinit_err = "is already initialized"

        with self.assertRaisesRegex(RuntimeError, expected_reinit_err):
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.shutdown()

    @dist_init(setup_rpc=False)
    def test_pg_init_no_rpc_init(self):
        dist.init_process_group(
            backend='gloo',
            init_method=self.file_init_method,
            rank=self.rank,
            world_size=self.world_size)

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(3, 4)

            def forward(self, x):
                return self.lin(x)

        model = MyModel()
        model.train()
        model = torch.nn.parallel.DistributedDataParallel(model)

        with self.assertRaisesRegex(RuntimeError, 'Current RPC agent is not set! Did you initialize the RPC framework'):
            params = []
            for param in model.parameters():
                params.append(RRef(param))

    def test_world_size_one(self):
        self._world_size_one(
            torch.ones(2, 2),
            torch.ones(2, 2)
        )

    @dist_init(setup_rpc=False)
    def test_invalid_names(self):

        worker_id = 0
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            info = WorkerInfo("abc*", worker_id)

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            info = WorkerInfo(" ", worker_id)

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            info = WorkerInfo("", worker_id)

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            info = WorkerInfo("".join(["a" for i in range(500)]), worker_id)

    # Test that WorkerInfo can be pickled and sent in RPC call
    @dist_init
    def test_worker_info_pickle(self):
        dst_rank = (self.rank + 1) % self.world_size
        worker_info = rpc.api.get_worker_info()
        ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,))
        self.assertEqual(ret, worker_info)

    @dist_init
    def test_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @staticmethod
    def return_callee_id():
        return rpc.get_worker_info().id

    @dist_init
    def test_int_callee(self):
        dst_rank = (self.rank + 1) % self.world_size
        ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
        self.assertEqual(ret, dst_rank)

    @dist_init
    def test_add_with_id(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        worker_info = rpc.get_worker_info(worker_name(dst_rank))

        ret = rpc.rpc_sync(
            worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init
    def test_scalar_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), n)
        )
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @dist_init
    def test_async_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async(
            worker_name(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @dist_init
    def test_nonzero(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @dist_init
    def test_multi_rpc(self):
        self._multi_rpc(False)

    @dist_init
    def test_future_wait_twice(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        futs = []
        for i in range(20):
            futs.append(rpc.rpc_async(dst, raise_func))

        with self.assertRaisesRegex(ValueError, "Expected error"):
            torch.futures.wait_all(futs)

        for fut in futs:
            with self.assertRaisesRegex(ValueError, "Expected error"):
                fut.wait()

    @dist_init(setup_rpc=False)
    def test_wait_all_workers_timeout(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        og_func = rpc.api._wait_all_workers

        def wait_all_workers_sleep(timeout):
            rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)

        rpc.api._wait_all_workers = wait_all_workers_sleep

        try:
            with self.assertRaisesRegex(RuntimeError, ''):
                rpc.shutdown(graceful=True, timeout=0.01)
        finally:
            rpc.api._wait_all_workers = og_func
        dist.barrier()

    def test_wait_all_workers_dense(self):
        self._wait_all_workers(heavy_rpc, torch.ones(100, 100))

    def test_wait_all_workers_twice_dense(self):
        self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100))

    @dist_init
    def test_all_gather(self):
        info = rpc.get_worker_info()
        results = rpc.api._all_gather(info.id)
        expected = {}
        for info in rpc._get_current_rpc_agent().get_worker_infos():
            expected[info.name] = info.id

        self.assertEqual(expected, results)

    @dist_init
    def test_all_gather_timeout(self):
        rpc._set_rpc_timeout(0.1)

        if self.rank == 0:
            with self.assertRaisesRegex(
                RuntimeError,
                "timed out in _all_gather after 0\\.10 seconds"
            ):
                rpc.api._all_gather(SlowPickleClass(0.5))
        else:
            expected_error = self.get_timeout_error_regex()
            with self.assertRaisesRegex(RuntimeError, expected_error):
                rpc.api._all_gather(SlowPickleClass(0.5))

    def _test_barrier_helper(self, info, names, multi_threaded=False):
        names = sorted(names)
        leader = names[0]
        rpc.rpc_sync(leader, _reset_count)
        if not multi_threaded and info.name == leader:
            self.assertEqual(_rpc_barrier_count, 0)
        rpc.api._barrier(names)
        rpc.rpc_sync(leader, _increment_count)
        rpc.api._barrier(names)
        if not multi_threaded and info.name == leader:
            self.assertEqual(_rpc_barrier_count, len(names))
1644
1645    @dist_init
1646    def test_rpc_barrier_all(self):
        # Test rpc barrier when called with the full list of workers.
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        names = [worker.name for worker in all_worker_info]
        self._test_barrier_helper(info, names)

    @dist_init
    def test_rpc_barrier_subset(self):
        # Test rpc barrier when processes are called with different subsets of the full list
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        if info.id % 2:
            names = [worker.name for worker in all_worker_info if worker.id % 2]
        else:
            names = [worker.name for worker in all_worker_info if not worker.id % 2]
        self._test_barrier_helper(info, names)

    @dist_init
    def test_rpc_barrier_partial_subset(self):
        # Test rpc barrier when some processes are not involved in the barrier
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        if info.id % 2:
            names = [worker.name for worker in all_worker_info if worker.id % 2]
        else:
            names = [f"worker{info.id}"]
        self._test_barrier_helper(info, names)

    @dist_init
    def test_rpc_barrier_multithreaded(self):
        # This test validates the implementation of barrier when multiple
        # threads call into it. We only need to check that it does not hang.
        info = rpc.get_worker_info()
        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
        names = [worker.name for worker in all_worker_info]
        threads = []
        for _ in range(3):
            th = threading.Thread(target=self._test_barrier_helper, args=(info, names, True))
            threads.append(th)
            th.start()
        for th in threads:
            th.join()

    @dist_init
    def test_graceful_shutdown_with_uneven_workload(self):
        """Test graceful termination."""
        self._run_uneven_workload(heavy_rpc, torch.ones(100, 100))

    @dist_init(setup_rpc=False)
    def test_shutdown_followed_by_rpc(self):
        # Initialize RPC.
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        rpc.shutdown()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            rpc.rpc_sync(
                worker_name(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

    @dist_init
    def test_expected_src(self):
        dst_rank = (self.rank + 1) % self.world_size
        expected_src_rank = (self.rank - 1) % self.world_size
        ret = rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,))
        value = VALUE_FUTURE.result()
        self.assertEqual(value, expected_src_rank)

    @dist_init
    def test_py_built_in(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(worker_name(dst_rank), min, args=(n, n + 1, n + 2))
        self.assertEqual(ret, min(n, n + 1, n + 2))

    @dist_init
    def test_py_user_defined(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            my_function,
            kwargs={"a": n, "b": n + 1, "c": n + 2},
        )
        self.assertEqual(ret, my_function(n, n + 1, n + 2))

    def test_build_rpc_profiling_key(self):
        # Tests that the name that shows up as an Event in profiling RPCs has all
        # the necessary information.
        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            rpc_profiling_key = _build_rpc_profiling_key(
                exec_mode, "foo", "worker0", "worker1"
            )
            self.assertIn(exec_mode.value, rpc_profiling_key)
            self.assertIn("foo", rpc_profiling_key)
            self.assertIn("worker0", rpc_profiling_key)
            self.assertIn("worker1", rpc_profiling_key)

    def check_profiling_info(self, self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode):
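        # The event name should embed the source worker, the destination
        # worker, the (qualified) function name, and the RPC execution mode;
        # each RPC should be recorded exactly once.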
        self.assertTrue(self_worker_name in rpc_event.name)
        self.assertTrue(dst_worker_name in rpc_event.name)
        if isinstance(func, torch.jit.ScriptFunction):
            self.assertTrue(torch._jit_internal._qualified_name(func) in rpc_event.name)
        else:
            self.assertTrue(func.__name__ in rpc_event.name)
        self.assertTrue(rpc_exec_mode.value in rpc_event.name)
        self.assertEqual(rpc_event.count, 1)

    @dist_init
    def test_profiler_rpc_record_shapes(self):
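        # Profile a remote torch.add with record_shapes=True and verify that
        # the remote event reports the same input shapes as a local run.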
        if self.rank != 1:
            return
        dst = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst)
        t1, t2 = torch.ones(100), torch.ones(100)
        with _profile(record_shapes=True) as prof:
            rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2))

        function_events = prof.function_events
        remote_events = [event for event in function_events if event.is_remote]
        remote_add_event = next(
            event for event in remote_events if "aten::add" in event.name
        )
        remote_add_input_shapes = remote_add_event.input_shapes
        # Run profiler on equivalent local op and validate shapes are the same.
        with _profile(record_shapes=True) as prof:
            torch.add(t1, t2)

        local_function_events = prof.function_events
        local_add_event = next(
            event for event in local_function_events if "aten::add" in event.name
        )
        local_add_input_shapes = local_add_event.input_shapes
        self.assertEqual(remote_add_input_shapes, local_add_input_shapes)

    @dist_init
    def test_profiler_rpc_memory(self):
        if self.rank != 1:
            return
        dst = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst)
        with _profile(profile_memory=True) as p:
            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
            res = fut.wait()

        function_events = p.function_events
        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
        # If cpu_memory_usage was not propagated over the wire, this set would
        # only contain 0 (indicating that no memory usage was profiled).
        self.assertNotEqual({0}, event_cpu_mem_usages)
        # No memory profiled if profile_memory=False
        with _profile(profile_memory=False) as p:
            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
            res = fut.wait()

        function_events = p.function_events
        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
        self.assertEqual({0}, event_cpu_mem_usages)

    @dist_init
    def test_profiler_export_trace(self):
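        # Export the profiler results to a Chrome trace file and verify that
        # the expected remote events appear in the exported JSON.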
        if self.rank != 1:
            return
        dst = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst)
        with _profile() as p:
            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
            res = fut.wait()

        events = p.function_events
        with TemporaryFileName() as fname:
            path = fname
            p.export_chrome_trace(path)
            with open(path) as f:
                trace = json.load(f)
                event_names = [event['name'] for event in trace]
                for expected_event_name in EXPECTED_REMOTE_EVENTS + [RPCExecMode.ASYNC.value]:
                    event_exists = any(expected_event_name in event_name for event_name in event_names)
                    self.assertTrue(event_exists)

    @dist_init
    def test_profiler_rpc_key_names(self):
        # tests that remote events are properly prefixed with the RPC profiling key.
        if self.rank != 1:
            return

        # Spawn multiple threads that send RPCs to ensure keys are correctly
        # prefixed when there are multiple RPCs being created/in flight at the
        # same time.
        dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank]

        def rpc_with_profiling(dst_worker):
            with _profile() as prof:
                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
                fut.wait()

            events = prof.function_events
            remote_event_names = {
                event.name: event for event in events if event.is_remote
            }
            rpc_profiling_key = _build_rpc_profiling_key(
                RPCExecMode.ASYNC,
                udf_with_torch_ops.__qualname__,
                worker_name(self.rank),
                dst_worker,
            )

            remote_event_name_set = set(EXPECTED_REMOTE_EVENTS)
            for name, event in remote_event_names.items():
                # Ensure that we have the expected key as part of the remote
                # event.
                self.assertTrue(name.startswith(rpc_profiling_key))
                self.assertTrue(event.is_remote)
                self.assertTrue(event.node_id == rpc.get_worker_info(dst_worker).id)
                # Ensure that the remote event name also contains the operator.
                operator_name_substr = name[len(rpc_profiling_key) :]
                # Note: we don't assert that every remote event must be in the
                # above set; the set is just representative of what we expect
                # to see. The profiler can change and add more events, but we
                # should always expect to see this representative set.
                matching_event = {
                    remote_event_name
                    for remote_event_name in remote_event_name_set
                    if remote_event_name in operator_name_substr
                }
                remote_event_name_set -= matching_event

            # The set should be empty by now; otherwise, its remaining
            # elements did not show up in the remote profiler output.
            self.assertTrue(
                remote_event_name_set == set(),
                f"Expected {remote_event_name_set} to be included in remote profiler output.",
            )

        for dst in dst_ranks:
            dst_worker = worker_name(dst)
            num_parallel_rpcs = 2
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=num_parallel_rpcs
            ) as executor:
                futs = [
                    executor.submit(rpc_with_profiling, dst_worker)
                    for _ in range(num_parallel_rpcs)
                ]
                # Wait for workers to finish test
                for fut in futs:
                    fut.result()

    def _run_test_profiler_remote_events_profiled(self):
        # Tests that we can successfully invoke the profiler on a remote node,
        # and collect the remote events back in the local profiler.
        if self.rank != 1:
            return

        dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank]
        for dst in dst_ranks:
            dst_worker = worker_name(dst)
            with _profile() as prof:
                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
                ret = fut.wait()

            events = prof.function_events

            rpc_event = get_function_event(events, RPCExecMode.ASYNC.value)
            self.check_profiling_info(
                worker_name(self.rank),
                dst_worker,
                udf_with_torch_ops,
                rpc_event,
                RPCExecMode.ASYNC,
            )

            remote_events = {event.name: event for event in events if event.is_remote}
            rpc_profiling_key = _build_rpc_profiling_key(
                RPCExecMode.ASYNC,
                udf_with_torch_ops.__qualname__,
                worker_name(self.rank),
                worker_name(dst),
            )

            for expected_remote_event_name in EXPECTED_REMOTE_EVENTS:
                expected_key = rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name
                self.assertTrue(expected_key in remote_events)
                remote_event = remote_events[expected_key]
                # Remote event should have a node ID corresponding to the worker
                # it ran on.
                self.assertEqual(remote_event.node_id, dst)

            # Validate the order in which remote events show up in the
            # profiling output.
            def convert_remote_to_local(event_name):
                remote_op_key = rpc_profiling_key + REMOTE_OP_STR
                return event_name[
                    event_name.find(remote_op_key)
                    + len(remote_op_key) :
                ]

            remote_events_list = [
                convert_remote_to_local(event.name)
                for event in events
                if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS
            ]
            self.assertEqual(
                set(remote_events_list),
                set(EXPECTED_REMOTE_EVENTS),
                f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}",
            )

    @dist_init
    def test_profiler_remote_events_profiled(self):
        self._run_test_profiler_remote_events_profiled()

    @dist_init
    def test_profiler_remote_events_profiled_single_threaded(self):
        self._run_test_profiler_remote_events_profiled()

    def run_profiling_workload(self, dst):
        fut = rpc.rpc_async(
            worker_name(dst),
            torch.mul,
            args=(
                torch.tensor(1.0, requires_grad=True),
                torch.tensor(1.0, requires_grad=True),
            ),
        )
        fut.wait()

    def _run_rpc_profiling_async_function(self, device="cpu"):
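        # Profile a nested RPC: this rank calls slow_async_add on dst1, which
        # in turn issues an RPC to dst2. Both levels should show up in the
        # profile with nested keys and the correct node ids.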
        if self.rank != 1:
            return

        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)
        x = torch.ones(2)
        y = torch.ones(2)
        with _profile() as prof:
            ret = rpc.rpc_async(
                dst1, slow_async_add, args=(dst2, x, y, device), timeout=20
            )
            out = ret.wait()

        function_events = prof.function_events
        # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be
        # recorded.
        key_prefix = _build_rpc_profiling_key(
            RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1
        )

        nested_rpc_key_prefix = _build_rpc_profiling_key(
            RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2
        )
        expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix
        remote_events = [event for event in function_events if event.is_remote]
        rpc_remote_event = [
            event for event in remote_events if event.name == expected_key
        ]
        self.assertEqual(1, len(rpc_remote_event))
        rpc_remote_event = rpc_remote_event[0]
        self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size)
        # slow_async_add's RPC does an add on dst2, which should be reflected as well.
        remote_add_key = (
            expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add)
        )
        remote_add_event = [
            event for event in remote_events if event.name == remote_add_key
        ]
        self.assertEqual(1, len(remote_add_event))
        remote_add_event = remote_add_event[0]
        # Validate that node_id is dst2.
        self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size)

    @dist_init
    def test_rpc_profiling_async_function(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        self._run_rpc_profiling_async_function()
        if torch.cuda.is_available():
            dist.barrier()
            self._run_rpc_profiling_async_function(device="cuda:0")

    @dist_init
    def test_rpc_profiling_async_function_single_threaded(self):
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        self._run_rpc_profiling_async_function()
        if torch.cuda.is_available():
            dist.barrier()
            self._run_rpc_profiling_async_function(device="cuda:0")

    @dist_init
    def test_rpc_profiling_remote_record_function(self):
        # test that functions run over RPC with record_function show the expected
        # profiled block.
        if self.rank != 1:
            return
        dst_ranks = [i for i in range(self.world_size) if i != self.rank]
        for dst_rank in dst_ranks:
            dst_worker = worker_name(dst_rank)
            with _profile() as prof:
                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=(-1, True))
                fut.wait()

            function_events = prof.function_events
            record_function_remote_event = [
                evt for evt in function_events if "##forward##" in evt.name
            ]
            self.assertEqual(1, len(record_function_remote_event))
            record_function_remote_event = record_function_remote_event[0]
            self.assertEqual(record_function_remote_event.node_id, dst_rank)
            # cpu_children only returns direct children, so here we get all
            # children recursively.

            def get_cpu_children(event):
                if not event.cpu_children:
                    return []
                # Copy the list so we don't mutate event.cpu_children or
                # extend the same list we are iterating over.
                cpu_children = list(event.cpu_children)
                for e in event.cpu_children:
                    cpu_children.extend(get_cpu_children(e))
                return cpu_children

            remote_children = get_cpu_children(record_function_remote_event)
            # Get local children and verify parity.
            with _profile() as prof:
                udf_with_torch_ops(-1, True)

            local_function_events = prof.function_events
            local_record_function_event = next(
                evt for evt in local_function_events if "##forward##" in evt.name
            )
            local_children = get_cpu_children(local_record_function_event)
            local_children_names = [
                evt.name for evt in local_children
            ]
            def convert_remote_to_local(event_name):
                remote_op_key = REMOTE_OP_STR
                return event_name[
                    event_name.find(remote_op_key) + len(remote_op_key) :
                ]

            for evt in remote_children:
                local_name = convert_remote_to_local(evt.name)
                self.assertTrue(local_name in local_children_names)

    def validate_profiling_workload(self, dst, prof):
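        # Check that the remote aten::mul issued by run_profiling_workload was
        # recorded on the destination node with the expected profiling info.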

        def convert_remote_to_local(event_name):
            return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]

        events = prof.function_events
        remote_events = {
            convert_remote_to_local(event.name): event
            for event in events
            if event.is_remote
        }
        self.assertTrue("aten::mul" in remote_events)
        remote_mul_event = remote_events["aten::mul"]
        self.assertEqual(remote_mul_event.node_id, dst)
        self.check_profiling_info(
            worker_name(self.rank),
            worker_name(dst),
            torch.mul,
            remote_mul_event,
            RPCExecMode.ASYNC,
        )

    def _run_test_profiler_with_autograd_context(self):
        dst = (self.rank + 1) % self.world_size
        if self.rank == 1:
            # Cases where we can double wrap messages with profiling information and autograd info.
            with dist_autograd.context() as context_id:
                with _profile() as prof:
                    self.run_profiling_workload(dst)

            self.validate_profiling_workload(dst, prof)

            # Ensure that flipped order of ctx managers results in events being
            # recorded as expected.
            with _profile() as prof:
                with dist_autograd.context() as context_id:
                    self.run_profiling_workload(dst)

            self.validate_profiling_workload(dst, prof)

    @dist_init
    def test_profiler_with_autograd_context_single_threaded(self):
        self._run_test_profiler_with_autograd_context()

    @dist_init
    def test_profiler_with_autograd_context(self):
        self._run_test_profiler_with_autograd_context()

    def _profiler_test_with_rpc(
        self, rpc_exec_mode, func, args, use_record_function=False, dst=None, kineto_profile=False
    ):
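        # Drives a single RPC (sync, async, or remote) under the profiler and
        # validates the recorded events. With kineto_profile=True, RPC
        # profiling is expected to be disabled entirely.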
        dst = dst if dst is not None else (self.rank + 1) % self.world_size

        # only run profiler on rank 1.
        p = _profile if not kineto_profile else torch.profiler.profile  # kineto
        if self.rank == 1:
            with p() as prof:
                record_function_ctx_mgr = (
                    contextlib.nullcontext()
                    if not use_record_function
                    else torch.autograd.profiler.record_function(
                        "foo"
                    )
                )
                with record_function_ctx_mgr as rf:
                    if rpc_exec_mode == RPCExecMode.SYNC:
                        rpc.rpc_sync(worker_name(dst), func, args=args)
                    elif rpc_exec_mode == RPCExecMode.ASYNC:
                        fut = rpc.rpc_async(worker_name(dst), func, args=args)
                        if kineto_profile:
                            # Ensure multiple async RPCs don't cause issues.
                            # If RPC profiling were not properly disabled for
                            # kineto, this would raise "RuntimeError: Cannot
                            # call RemoteProfilerManager::setCurrentKey when
                            # current key is already set."
                            fut2 = rpc.rpc_async(worker_name(dst), func, args=args)
                            fut2.wait()
                        fut.wait()
                    else:
                        self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE)
                        rref = rpc.remote(worker_name(dst), func, args=args)
                        rref.to_here()
                        # To avoid flakiness, wait for the RRef to be profiled. This
                        # means that we received the acknowledgement of successful
                        # creation on the owner and ran the callbacks responsible
                        # for recording the profiling event.
                        rref._get_profiling_future().wait()

            events = prof.function_events if not kineto_profile else prof.events()
            if kineto_profile:
                # RPC profiling is disabled so there should be no rpc related
                # events.
                with self.assertRaises(IndexError):
                    get_function_event(events, rpc_exec_mode.value)

                return

            rpc_event = get_function_event(events, rpc_exec_mode.value)
            # verify Node ID for this rpc event.
            self.assertEqual(rpc_event.node_id, self.rank)
            # Ensure recording of remote events.
            remote_events = {event for event in events if event.node_id == dst} - {rpc_event}
            self.assertGreaterEqual(len(remote_events), 1)
            for remote_event in remote_events:
                self.assertEqual(remote_event.node_id, dst)

            if use_record_function:
                scope_event = get_function_event(events, "foo")
                # Since RPC call is within the scope, its CPU interval should be
                # contained within foo's interval.
                self.assertLessEqual(scope_event.time_range.start, rpc_event.time_range.start)
                self.assertGreaterEqual(scope_event.time_range.end, rpc_event.time_range.end)
            # the sender, dest worker, function run, and type of RPC should all
            # be recorded.
            self_worker_name = worker_name(self.rank)
            dst_worker_name = worker_name(dst)
            self.check_profiling_info(self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode)
            if use_record_function:
                # verify order by ensuring that the outer context comes
                # before the rpc event.
                foo_event_idx = next(i for i, event in enumerate(events) if "foo" in event.name)
                rpc_event_idx = next(i for i, event in enumerate(events) if rpc_exec_mode.value in event.name)
                self.assertLess(foo_event_idx, rpc_event_idx)

    def _run_test_profiler_with_sync_rpc_udf(self):
        self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,))
        self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,),
                                     use_record_function=True)

    @dist_init
    def test_profiler_with_sync_rpc_udf(self):
        self._run_test_profiler_with_sync_rpc_udf()

    @dist_init
    def test_profiler_with_sync_rpc_udf_single_threaded(self):
        self._run_test_profiler_with_sync_rpc_udf()

    def _run_test_profiler_with_sync_rpc_builtin(self):
        self._profiler_test_with_rpc(
            RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
        )
        self._profiler_test_with_rpc(
            RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1)),
            use_record_function=True
        )

    @dist_init
    def test_profiler_with_sync_rpc_builtin(self):
        self._run_test_profiler_with_sync_rpc_builtin()

    @dist_init
    def test_profiler_with_sync_rpc_builtin_single_threaded(self):
        self._run_test_profiler_with_sync_rpc_builtin()

    def _run_test_profiler_with_async_rpc_udf(self):
        self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,))
        self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,),
                                     use_record_function=True)
        # Ensure that enabling the kineto profiler during RPC does not enable
        # RPC profiling (it is unsupported) and does not cause issues.
        self._profiler_test_with_rpc(
            RPCExecMode.ASYNC, my_sleep_func, args=(1,), kineto_profile=True
        )

    @dist_init
    def test_profiler_with_async_rpc_udf(self):
        self._run_test_profiler_with_async_rpc_udf()

    @dist_init
    def test_profiler_with_async_rpc_udf_single_threaded(self):
        self._run_test_profiler_with_async_rpc_udf()

    def _run_test_profiler_with_async_rpc_builtin(self):
        self._profiler_test_with_rpc(
            RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
        )
        self._profiler_test_with_rpc(
            RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1)),
            use_record_function=True
        )

    @dist_init
    def test_profiler_with_async_rpc_builtin(self):
        self._run_test_profiler_with_async_rpc_builtin()

    @dist_init
    def test_profiler_with_async_rpc_builtin_single_threaded(self):
        self._run_test_profiler_with_async_rpc_builtin()

    def _run_test_profiler_with_remote_udf(self):
        self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,))
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True
        )
        # test remote to self
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE, my_sleep_func, args=(1,), dst=self.rank
        )

    @dist_init
    def test_profiler_with_remote_udf(self):
        self._run_test_profiler_with_remote_udf()

    @dist_init
    def test_profiler_with_remote_udf_single_threaded(self):
        self._run_test_profiler_with_remote_udf()

    def _run_test_profiler_with_remote_builtin(self):
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1))
        )
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1)),
            use_record_function=True
        )
        # test remote to self
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE,
            torch.mul,
            args=(torch.ones(1), torch.ones(1)),
            dst=self.rank,
        )

    @dist_init
    def test_profiler_with_remote_builtin(self):
        self._run_test_profiler_with_remote_builtin()

    @dist_init
    def test_profiler_with_remote_builtin_single_threaded(self):
        self._run_test_profiler_with_remote_builtin()

    def _run_test_profiler_with_script_async_rpc(self):
        self._profiler_test_with_rpc(
            RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),)
        )
        self._profiler_test_with_rpc(
            RPCExecMode.ASYNC,
            my_script_func,
            args=(torch.tensor(1),),
            use_record_function=True,
        )

    @dist_init
    def test_profiler_with_script_async_rpc(self):
        self._run_test_profiler_with_script_async_rpc()

    @dist_init
    def test_profiler_with_script_async_rpc_single_threaded(self):
        self._run_test_profiler_with_script_async_rpc()

    def _run_test_profiler_with_script_sync_rpc(self):
        self._profiler_test_with_rpc(
            RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),)
        )
        self._profiler_test_with_rpc(
            RPCExecMode.SYNC,
            my_script_func,
            args=(torch.tensor(1),),
            use_record_function=True,
        )

    @dist_init
    def test_profiler_with_script_sync_rpc(self):
        self._run_test_profiler_with_script_sync_rpc()

    @dist_init
    def test_profiler_with_script_sync_rpc_single_threaded(self):
        self._run_test_profiler_with_script_sync_rpc()

    def _run_test_profiler_with_script_remote_rpc(self):
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),)
        )
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE,
            my_script_func,
            args=(torch.tensor(1),),
            use_record_function=True,
        )
        # test remote to self
        self._profiler_test_with_rpc(
            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank
        )

    @dist_init
    def test_profiler_with_script_remote_rpc(self):
        self._run_test_profiler_with_script_remote_rpc()

    @dist_init
    def test_profiler_with_script_remote_rpc_single_threaded(self):
        self._run_test_profiler_with_script_remote_rpc()

    def _assert_top_level_events(self, process_global_events, expected_top_level_event_names):
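        # An event counts as top-level on its thread if its start time falls
        # after the previous top-level event's end time; nested (overlapping)
        # events are skipped.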
        top_level_event_names = []
        for thread_local_events in process_global_events:
            # Get top-level events from all events that happened on a thread.
            last_end_time = 0
            for event in thread_local_events:
                event_name = event.name
                time_range = event.time_range
                if time_range.start > last_end_time:
                    top_level_event_names.append(event_name)
                    last_end_time = time_range.end
        top_level_event_names = sorted(top_level_event_names)
        expected_top_level_event_names = sorted(expected_top_level_event_names)
        self.assertEqual(
            top_level_event_names,
            expected_top_level_event_names,
            f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}",
        )

    @dist_init
    def test_server_process_global_profiler(self):
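        # Nest two server-side process-global profilers on the remote worker:
        # the inner one should capture only aten::sub, while the outer one
        # should capture both aten::add and aten::sub.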
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker_name = worker_name(dst_rank)

        x = torch.tensor(1)
        y = torch.tensor(2)

        outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        outer_profile_rref.rpc_sync().__enter__()
        rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
        inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        inner_profile_rref.rpc_sync().__enter__()
        rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
        inner_profile_rref.rpc_sync().__exit__(None, None, None)
        outer_profile_rref.rpc_sync().__exit__(None, None, None)

        inner_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (inner_profile_rref,))
        expected_inner_events = ['aten::sub']
        expected_outer_events = expected_inner_events + ['aten::add']

        self._assert_top_level_events(inner_events, expected_inner_events)
        outer_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (outer_profile_rref,))
        self._assert_top_level_events(outer_events, expected_outer_events)

        inner_profile_rref.rpc_sync().key_averages()
        outer_profile_rref.rpc_sync().key_averages()

    @dist_init
    def test_async_record_function_double_end_callbacks(self):
        num_sleep_seconds = 1
        if self.rank == 1:
            # Validate that calling the function twice results in an error.
            with _profile() as pf:
                with torch.autograd.profiler.record_function("foo") as rf:
                    fut = rpc.rpc_async(
                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
                    )
                    rf._call_end_callbacks_on_future(fut)
                    with self.assertRaisesRegex(
                        RuntimeError, "can only be called once."
                    ):
                        rf._call_end_callbacks_on_future(fut)
                fut.wait()

    @dist_init
    def test_async_record_function_legacy(self):
        # Test that the legacy _record_function ops work.
        # Note: these exist for backward compatibility with TorchScript.
        num_sleep_seconds = 1
        if self.rank == 1:
            with _profile() as pf:
                try:
                    handle = torch.ops.profiler._record_function_enter("foo", None)
                    fut = rpc.rpc_async(
                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
                    )
                    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
                finally:
                    torch.ops.profiler._record_function_exit(handle)

                fut.wait()

    @dist_init
    def test_async_record_function_cbs_jit_call(self):
        if self.rank == 1:
            with _profile() as pf:
                key = _build_rpc_profiling_key(
                    RPCExecMode.ASYNC,
                    torch._jit_internal._qualified_name(my_script_func),
                    "worker1",
                    "worker0",
                )
                with torch.autograd.profiler.record_function(key) as rf:
                    fut = rpc.rpc_async(
                        worker_name(0), my_script_func, args=(torch.tensor(1),)
                    )
                    # Intentionally calling record_function internals
                    fut = torch.ops.profiler._call_end_callbacks_on_jit_fut(rf.record, fut)
                result = fut.wait()
                # Validate that the profiling future returns the same value as the RPC
                # future.
                expected = torch.add(torch.tensor(1), torch.tensor(1))
                self.assertEqual(result, expected)
            events = pf.function_events
            rpc_event = get_function_event(
                events, torch._jit_internal._qualified_name(my_script_func)
            )
            self.assertTrue(torch._jit_internal._qualified_name(my_script_func) in rpc_event.name)

    @dist_init
    def test_py_class_constructor(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(worker_name(dst_rank), MyClass, args=(n,))
        self.assertEqual(ret.a, n)

    @dist_init
    def test_py_class_instance_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank), MyClass(2).my_instance_method, args=(n,)
        )
        self.assertEqual(ret, MyClass(2).my_instance_method(n))

    @dist_init
    def test_py_class_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank), MyClass.my_class_method, args=(n, n + 1)
        )
        self.assertEqual(ret, MyClass.my_class_method(n, n + 1))

    @dist_init
    def test_py_class_static_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank), MyClass.my_static_method, args=(n + 10,)
        )
        self.assertEqual(ret, MyClass.my_static_method(n + 10))

    @dist_init
    def test_py_multi_async_call(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        dst_worker_info = rpc.get_worker_info(worker_name(dst_rank))
        fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
        fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
        self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
        self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))

    @dist_init
    def test_py_no_return_result(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(worker_name(dst_rank), no_result)
        self.assertEqual(ret, no_result())

    @dist_init
    def test_py_tensors(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            my_tensor_function,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))

    @dist_init
    def test_py_tensors_multi_async_call(self):
        futs = []
        n = self.rank + 1
        dst_rank = n % self.world_size
        for i in range(100):
            fut = rpc.rpc_async(
                worker_name(dst_rank),
                my_tensor_function,
                args=(torch.ones(i, i), torch.ones(i, i)),
            )
            futs.append(fut)

        j = 0
        for val in torch.futures.wait_all(futs):
            self.assertEqual(
                val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
            )
            j += 1

    @dist_init
    def test_py_tensors_in_container(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        a = [torch.ones(n, n), torch.ones(n, n)]
        b = TensorClass(build_complex_tensors())
        c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
        ret = rpc.rpc_sync(
            worker_name(dst_rank), my_complex_tensor_function, args=(a, b, c)
        )
        self.assertEqual(ret, my_complex_tensor_function(a, b, c))

    @dist_init
    def test_py_nested_pickle(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            run_nested_pickle,
            args=(MyPickleClass(), torch.ones(2, 2)),
        )

        m = MyPickleClass()
        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))

    @dist_init
    def test_py_function_exception(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        with self.assertRaises(TypeError):
            ret = rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,))

    @dist_init
    def test_py_raise_in_user_func(self):
        with captured_output() as (_, err):
            # This barrier prevents a race condition where the main thread has
            # not entered the context manager when the remote function runs.
            initialize_pg(self.file_init_method, self.rank, self.world_size)
            dist.barrier()
            n = self.rank + 1
            dst_rank = n % self.world_size
            fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
            with self.assertRaisesRegex(ValueError, expected_err):
                fut.wait()
            # This barrier prevents a race condition where the main thread
            # exits the context manager before the remote function has run.
            dist.barrier()

        # Validate that trainers log errors when running functions.
        stderr_lines = err.getvalue()
        self.assertTrue(expected_err in stderr_lines)

    @dist_init
    def test_py_raise_in_user_func_escaped_str(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape)
        try:
            fut.wait()
        except ValueError as e:
            msg = str(e)
            # Ensure newlines are unescaped to provide a better repr of the error.
            self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape"))
        else:
            self.fail("expected raise_func_escape to raise ValueError.")

    @dist_init
    def test_nested_rpc(self):
        self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1)

    @dist_init
    def test_stress_light_rpc(self):
        self._stress_test_rpc(light_rpc)

    @dist_init
    def test_stress_heavy_rpc(self):
        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))

    @dist_init
    def test_stress_heavy_rpc_torchscript(self):
        self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),))

    @dist_init
    def test_builtin_remote_ret(self):
        self._builtin_remote_ret(
            torch.ones(2, 2),
            torch.ones(2, 2),
            torch.ones(2, 2) * 2
        )

    @dist_init
    def test_builtin_remote_self(self):
        self._builtin_remote_self(
            torch.ones(2, 2),
            torch.ones(2, 2),
            torch.ones(2, 2) * 2
        )

    @staticmethod
    def _multi_args_fn(n, sparse=False):
        if sparse:
            return (build_sparse_tensor(), build_sparse_tensor())
        else:
            return (torch.ones(n, n), torch.ones(n, n))

    @dist_init
    def test_multi_builtin_remote_ret(self):
        self._test_multi_remote_call(
            torch.add, False,
            args_fn=RpcTest._multi_args_fn
        )

    @dist_init
    def test_py_udf_remote(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = rpc.remote(
            worker_name(dst_rank),
            my_function,
            kwargs={"a": n, "b": n + 1, "c": n + 2},
        )
        self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))

    @staticmethod
    def _multi_kwargs_fn(n, sparse=False):
        if sparse:
            return {
                "a": build_sparse_tensor(),
                "b": build_sparse_tensor(),
                "c": build_sparse_tensor()
            }
        else:
            return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}

    @dist_init
    def test_multi_py_udf_remote(self):
        self._test_multi_remote_call(
            my_function,
            False,
            kwargs_fn=RpcTest._multi_kwargs_fn
        )

    @dist_init
    def test_py_rref_args(self):
        self._py_rref_args(
            torch.ones(2, 2),
            1,
            torch.ones(2, 2),
            2,
            torch.ones(2, 2) * 2 + 3)

    @dist_init
    def test_py_rref_args_user_share(self):
        self._py_rref_args_user_share(
            torch.ones(2, 2),
            1,
            2,
            torch.ones(2, 2),
            3,
            4,
            torch.ones(2, 2) * 2 + 10
        )

    @dist_init
    def test_py_rpc_rref_args(self):
        self._py_rpc_rref_args(
            torch.ones(2, 2),
            1,
            2,
            torch.ones(2, 2),
            3,
            4,
            torch.ones(2, 2) * 2 + 10
        )

    @dist_init
    def test_nested_remote(self):
        self._nested_remote(
            nested_remote,
            torch.ones(2, 2) + 3
        )

    @dist_init
    def test_nested_rref(self):
        self._nested_rref(
            nested_rref,
            torch.ones(2, 2) + 1,
            torch.ones(2, 2) + 2
        )

    @dist_init
    def test_nested_rref_stress(self):
        self._nested_rref_stress(
            nested_rref,
            torch.ones(2, 2) + 1,
            torch.ones(2, 2) + 2
        )

    @dist_init
    def test_multi_layer_nested_async_rpc(self):
        # This test will exit right away, but there will be a chain of async
        # RPCs. The termination algorithm should detect those messages
        # properly. Otherwise, some peer could exit early, leaving others to
        # hit timeout or connection-closed errors.
        ttl = 20
        n = self.rank + 1
        dst_rank = n % self.world_size

        multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)

    @dist_init
    def test_remote_with_exception(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        # check ref to other workers
        rref = rpc.remote(worker_name(dst_rank), raise_func)
        with self.assertRaises(ValueError):
            rref.to_here()
        # check ref to itself
        rref = rpc.remote(worker_name(self.rank), no_result, args=(10,))
        with self.assertRaises(TypeError):
            rref.to_here()

    @dist_init
    def test_rpc_return_rref(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref = rpc.rpc_sync(
            worker_name(dst_rank1),
            rpc_return_rref,
            args=(worker_name(dst_rank2),),
        )
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)

    @dist_init
    def test_rref_forward_chain(self):
        ttl = 8
        n = self.rank + 1
        dst_rank = n % self.world_size

        rref = rpc.remote(
            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )

        ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)

        for i in range(ttl):
            self.assertEqual(len(ret_rref), 1)
            ret_rref = ret_rref[0].to_here()

        ret = ret_rref
        self.assertEqual(ret, torch.add(torch.ones(n, n), 1))

    @dist_init
    def test_local_rref_no_fork(self):
        local_rref = RRef(35)
        self.assertEqual(local_rref.local_value(), 35)

    @dist_init
    def test_local_value_not_on_owner(self):
        # Ensure that an error is raised if a user tries to call
        # local_value() on a non-owning node.
        next_rank = (self.rank + 1) % self.world_size
        rref = rpc.remote(
            worker_name(next_rank), torch.add, args=(torch.ones(1), torch.ones(1))
        )
        with self.assertRaisesRegex(
            RuntimeError, (
                fr"For UserRRef\(rref_id=GloballyUniqueId\(created_on={self.rank}, local_id=0\), "
                fr"fork_id=GloballyUniqueId\(created_on={self.rank}, local_id=1\)\), "
                r"can't call localValue\(\) on user "
                fr"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). "
                fr"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)"
            )
        ):
            rref.local_value()

    @dist_init
    def test_return_local_rrefs(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        rref_list = rpc.rpc_sync(
            worker_name(dst_rank), get_rref_list, args=([1, 2, 3],)
        )

        for rref in rref_list:
            rpc.rpc_sync(
                rref.owner(),
                _call_method_on_rref,
                args=(MyClass.increment_value, rref, 10),
            )

        rets = [
            rpc.rpc_sync(
                rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
            )
            for rref in rref_list
        ]

        self.assertEqual(rets, [11, 12, 13])

    @dist_init
    def _test_rref_type(self, blocking):
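        # _get_type should issue an RPC to the owner only on the first call;
        # subsequent calls must be served from the cache without launching
        # any additional RPCs.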

        def launched_rpc(events):
            expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner"
            return any(e.name.startswith(expected_name) for e in events)

        dst = worker_name((self.rank + 1) % self.world_size)
        rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1))

        with _profile() as p:
            t = rref._get_type(blocking=blocking)
            if not blocking:
                t = t.wait()

        self.assertTrue(launched_rpc(p.function_events))
        expected_type = type(torch.ones(2))
        self.assertEqual(t, expected_type)

        futs = []

        def verify(fut):
            self.assertEqual(fut.value(), expected_type)

        with _profile() as p:
            for _ in range(10):
                t = rref._get_type(blocking=blocking)
                if not blocking:
                    futs.append(t)
                    t.add_done_callback(verify)
                    t = t.wait()
                self.assertEqual(t, expected_type)

        if not blocking:
            # Note that cached calls with blocking=False all return the same
            # cached original future.
            first_fut = futs[0]
            for f in futs[1:]:
                self.assertTrue(f is first_fut)
        # Ensure we never launch another RPC, other than for the very
        # first call.
        self.assertFalse(launched_rpc(p.function_events))
        self.assertEqual(t, type(torch.ones(2)))

        rref = rpc.remote(dst, MyClass, args=(0,))
        rref_type = rref._get_type(blocking=blocking)
        if not blocking:
            rref_type = rref_type.wait()
        self.assertEqual(rref_type, MyClass)

    def test_rref_type_blocking(self):
        self._test_rref_type(blocking=True)

    def test_rref_type_non_blocking(self):
        self._test_rref_type(blocking=False)

    @dist_init
    def _test_rref_type_with_error(self, blocking):
        dst = worker_name((self.rank + 1) % self.world_size)
        rref = rpc.remote(dst, raise_func)
2960        # Blocking: error raised inline
2961        if blocking:
2962            with self.assertRaisesRegex(ValueError, "Expected error"):
2963                rref._get_type(blocking=blocking)
2964        else:
2965            # Non-blocking: Immediately return future, block on wait
2966            fut = rref._get_type(blocking=blocking)
2967            with self.assertRaisesRegex(ValueError, "Expected error"):
2968                fut.wait()
2969
2970
2971    def test_rref_type_with_error_blocking(self):
2972        self._test_rref_type_with_error(blocking=True)
2973
2974    def test_rref_type_with_error_non_blocking(self):
2975        self._test_rref_type_with_error(blocking=False)
2976
2977    @dist_init
2978    def _test_rref_type_owner(self, blocking):
2979        rref = RRef(torch.ones(2) + 1)
2980        rref_type = rref._get_type(blocking=blocking)
2981        if not blocking:
2982            rref_type = rref_type.wait()
2983        self.assertEqual(rref_type, type(torch.ones(2)))
2984
2985        rref = RRef(MyClass(0))
2986        rref_type = rref._get_type(blocking=blocking)
2987        if not blocking:
2988            rref_type = rref_type.wait()
2989        self.assertEqual(rref_type, MyClass)
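        # Unlike the user-RRef cases above, an owner RRef holds its value
        # locally, so the type can be resolved without launching an RPC.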
2990
2991    def test_rref_type_owner_blocking(self):
2992        self._test_rref_type_owner(blocking=True)
2993
2994    def test_rref_type_owner_non_blocking(self):
2995        self._test_rref_type_owner(blocking=False)
2996
2997    @staticmethod
2998    def _slow_add(x, y):
2999        time.sleep(1)
3000        return x + y
3001
3002    @dist_init
3003    def test_rref_type_slow_init(self):
3004        dst = worker_name((self.rank + 1) % self.world_size)
3005        rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1))
3006        self.assertEqual(rref._get_type(), type(torch.ones(2)))
3007
3008    @dist_init
3009    def test_owner_equality(self):
3010        a = RRef(40)
3011        b = RRef(50)
3012
3013        other_rank = (self.rank + 1) % self.world_size
3014        other_a = rpc.remote(
3015            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
3016        )
3017        other_b = rpc.remote(
3018            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
3019        )
3020        other_a.to_here()  # to ensure clean termination
3021        other_b.to_here()
3022
3023        self.assertNotEqual(a.owner(), 23)
3024        self.assertEqual(other_a.owner(), other_b.owner())
3025        self.assertNotEqual(a.owner(), other_a.owner())
3026        self.assertEqual(other_a.owner(), other_a.owner())
3027        self.assertEqual(other_a.owner(), other_b.owner())
3028        self.assertEqual(a.owner(), a.owner())
3029        self.assertEqual(a.owner(), b.owner())
3030        self.assertEqual(a.owner(), rpc.get_worker_info())
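        # WorkerInfo is hashable and compares by worker identity, so owner()
        # can serve as a dict key below.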
3031        x = {}
3032        x[a.owner()] = a
3033        x[other_a.owner()] = other_a
3034        self.assertEqual(x[a.owner()], a)
3035        self.assertEqual(x[b.owner()], a)
3036        self.assertEqual(x[other_a.owner()], other_a)
3037        self.assertEqual(x[other_b.owner()], other_a)
3038        self.assertEqual(len(x), 2)
3039
3040    @dist_init
3041    def test_pass_local_rrefs(self):
3042        n = self.rank + 1
3043        dst_rank = n % self.world_size
3044        dst_worker = worker_name(dst_rank)
3045
3046        rref = RRef(40)
3047        self.assertEqual(
3048            rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90
3049        )
3050        self.assertEqual(
3051            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90
3052        )
3053        self.assertEqual(
3054            rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90
3055        )
3056
3057    @dist_init
3058    def test_remote_same_worker(self):
3059        n = self.rank + 1
3060        dst_rank = n % self.world_size
3061        rref_a = rpc.remote(
3062            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
3063        )
3064        rref_b = rpc.remote(
3065            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
3066        )
3067        rref_c = rpc.remote(
3068            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
3069        )
3070        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
3071
3072    @dist_init(setup_rpc=True)
3073    def test_call_method_on_rref(self):
3074        """
3075        Tests that it is possible to call an instance method on a remote object
3076        by using rref.owner() as the destination of the call.
3077        """
3078        vals = [10, 2, 5, 7]
3079        dst_rank = (self.rank + 1) % self.world_size
3080        dst_worker = worker_name(dst_rank)
3081
3082        # creates a remote object
3083        rref = rpc.remote(dst_worker, MyClass, args=(vals[0],))
3084
3085        # modifies state of the remote object
3086        rpc.rpc_sync(
3087            rref.owner(),
3088            _call_method_on_rref,
3089            args=(MyClass.increment_value, rref, vals[1]),
3090        )
3091        rpc.rpc_async(
3092            rref.owner(),
3093            _call_method_on_rref,
3094            args=(MyClass.increment_value, rref, vals[2]),
3095        ).wait()
3096        rpc.remote(
3097            rref.owner(),
3098            _call_method_on_rref,
3099            args=(MyClass.increment_value, rref, vals[3]),
3100        ).to_here()
3101
3102        # queries state of the remote object
3103        result = rpc.rpc_sync(
3104            dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
3105        )
3106
3107        self.assertEqual(result, sum(vals))
3108
3109    # Notice `rpc.api.shutdown()` accesses
3110    # `_delete_all_user_and_unforked_owner_rrefs` through
3111    # `torch.distributed.rpc.api`, so patching
3112    # `torch.distributed.rpc._delete_all_user_and_unforked_owner_rrefs` will
3113    # not help.
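    # Illustrative sketch of the distinction (hypothetical module names): if
    # api.py does `from x import f`, patching `api.f` replaces the name that
    # api's code resolves at call time, while patching `x.f` leaves api's
    # already-imported reference untouched.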
3114    @mock.patch.object(torch.distributed.rpc.api, "_delete_all_user_and_unforked_owner_rrefs")
3115    def _test_rref_leak(self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak):
3116        rpc.init_rpc(
3117            name=worker_name(self.rank),
3118            backend=self.rpc_backend,
3119            rank=self.rank,
3120            world_size=self.world_size,
3121            rpc_backend_options=self.rpc_backend_options,
3122        )
3123
3124        initialize_pg(self.file_init_method, self.rank, self.world_size)
3125        # Wait for all init to complete.
3126        dist.barrier()
3127
3128        rref = rpc.remote(
3129            worker_name((self.rank + 1) % self.world_size),
3130            torch.add,
3131            args=(torch.ones(2, 2), 1),
3132        )
3133
3134        import torch.distributed.rpc.api as api
3135
3136        if ignore_leak:
3137            api._ignore_rref_leak = True
3138            rpc.shutdown(graceful=True)
3139        else:
3140            api._ignore_rref_leak = False
3141            with self.assertRaisesRegex(RuntimeError, "Leaking RRef"):
3142                rpc.shutdown(graceful=True)
3143
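    # Summary of the knob exercised above: with api._ignore_rref_leak left
    # False, a graceful shutdown that still has live user RRefs raises
    # "Leaking RRef"; setting the flag to True lets shutdown proceed anyway.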
3144    @dist_init(setup_rpc=False)
3145    def test_rref_leak(self):
3146        self._test_rref_leak(ignore_leak=False)
3147
3148    @dist_init(setup_rpc=False)
3149    def test_ignore_rref_leak(self):
3150        self._test_rref_leak(ignore_leak=True)
3151
3152    @dist_init
3153    def test_rref_str(self):
3154        rref1 = RRef(self.rank)
3155        id_class = "GloballyUniqueId"
3156        self.assertEqual(
3157            f"OwnerRRef({id_class}(created_on={self.rank}, local_id=0))", rref1.__str__()
3158        )
3159
3160        dst_rank = (self.rank + 1) % self.world_size
3161        rref2 = rpc.remote(
3162            worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
3163        )
3164        self.assertEqual(
3165            rref2.__str__(),
3166            f"UserRRef(RRefId = {id_class}(created_on={self.rank}, local_id=1), "
3167            f"ForkId = {id_class}(created_on={self.rank}, local_id=2))",
3168        )
3169
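    # Note on the ids checked above: GloballyUniqueId pairs the creating rank
    # with a per-worker counter, so the local OwnerRRef takes local_id=0 and
    # the rpc.remote call consumes local_id=1 (RRefId) and 2 (ForkId).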
3170    @dist_init
3171    def test_rref_get_future(self):
3172        # Tests that we can obtain the future corresponding to the creation of
3173        # the RRef on the remote end.
3174        if self.rank == 0:
3175            # Builtin
3176            rref = rpc.remote(worker_name(1), torch.add, args=(1, 1))
3177            rref.to_here()
3178            fut = rref._get_future()
3179            self.assertIsInstance(fut, torch._C.Future)
3180
3181            # UDF
3182            rref = rpc.remote(worker_name(1), foo_add, args=())
3183            rref.to_here()
3184            fut = rref._get_future()
3185            self.assertIsInstance(fut, torch._C.Future)
3186
3187            # Script
3188            rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1), ))
3189            rref.to_here()
3190            fut = rref._get_future()
3191            self.assertIsInstance(fut, torch._C.Future)
3192
3193
3194    @dist_init
3195    def test_rref_context_debug_info(self):
3196        # This test checks local states that are modified by remote workers.
3197        # This means that we need a barrier before and after every check.
3198        # The barrier before a check makes sure that all previous states are
3199        # cleared globally; the barrier after ensures that no subsequent state
3200        # change leaks into the current check.
3201        initialize_pg(self.file_init_method, self.rank, self.world_size)
3202
3203        # Check 1: local RRef does not update owners_ map or add a pending user.
3204        #################################################
3205
3206        rref1 = RRef(self.rank)
3207
3208        # don't need a barrier here as local RRef is handled by this thread
3209        info = _rref_context_get_debug_info()
3210        self.assertIn("num_owner_rrefs", info)
3211        self.assertIn("num_pending_users", info)
3212        # An RRef on a local value is not added to the context until shared over RPC.
3213        self.assertEqual(0, int(info["num_owner_rrefs"]))
3214        self.assertEqual(0, int(info["num_pending_users"]))
3215        # barrier after the check 1
3216        dist.barrier()
3217
3218        # Check 2: Sharing RRef as an arg should update owners_ map
3219        ###########################################################
3220
3221        dst_rank = (self.rank + 1) % self.world_size
3222        rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,))
3223
3224        # barrier before check 2
3225        wait_until_pending_futures_and_users_flushed()
3226        dist.barrier()
3227
3228        info = _rref_context_get_debug_info()
3229        self.assertIn("num_owner_rrefs", info)
3230        self.assertEqual(1, int(info["num_owner_rrefs"]))
3231        # no pending users since the fork is finished
3232        self.assertEqual(0, int(info["num_pending_users"]))
3233        # barrier after check 2
3234        dist.barrier()
3235
3236        # clear states for check 2
3237        rpc.rpc_sync(worker_name(dst_rank), clear_global_rref)
3238
3239        # The owner RRef is deleted asynchronously; poll until it is cleared.
3240        while int(info["num_owner_rrefs"]) != 0:
3241            info = _rref_context_get_debug_info()
3242            time.sleep(0.1)
3243        dist.barrier()
3244
3245        # Check 3: rpc.remote call should update owners_ map
3246        ####################################################
3247        rref2 = rpc.remote(
3248            worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
3249        )
3250        rref3 = rpc.remote(
3251            worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
3252        )
3253        rref2.to_here()
3254        rref3.to_here()
3255
3256        # barrier before check 3
3257        wait_until_pending_futures_and_users_flushed()
3258        dist.barrier()
3259
3260        info = _rref_context_get_debug_info()
3261        self.assertIn("num_owner_rrefs", info)
3262        self.assertEqual(2, int(info["num_owner_rrefs"]))
3263        # no pending users since the fork is finished
3264        self.assertEqual(0, int(info["num_pending_users"]))
3265
3266        # barrier after check 3
3267        dist.barrier()
3268
3269    @dist_init
3270    def test_disable_gil_profiling(self):
3271        # Test that GIL wait time is not recorded unless
3272        # rpc.enable_gil_profiling(True) has been called.
3273
3274        # GIL profiling should be disabled by default.
3275        dst_rank = (self.rank + 1) % self.world_size
3276        rpc.rpc_sync(
3277            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
3278        )
3279        info = rpc.api._get_current_rpc_agent().get_debug_info()
3280        self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"])
3281        rpc.enable_gil_profiling(True)
3282        rpc.rpc_sync(
3283            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
3284        )
3285        info = rpc.api._get_current_rpc_agent().get_debug_info()
3286        self.assertIn("agent.gil_average_wait_time_us", info)
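        # Once profiling is enabled, the metric can be read as, e.g.:
        #   float(info["agent.gil_average_wait_time_us"])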
3287
3288    @dist_init(setup_rpc=False)
3289    def test_local_shutdown(self):
3290        # Test that we can start RPC and then immediately shut down locally
3291        # without sending any messages.
3292        rpc.init_rpc(
3293            name="worker%d" % self.rank,
3294            backend=self.rpc_backend,
3295            rank=self.rank,
3296            world_size=self.world_size,
3297            rpc_backend_options=self.rpc_backend_options,
3298        )
3299        # pass in graceful=False to ensure that we don't wait for other workers.
3300        rpc.shutdown(graceful=False)
3301
3302    @dist_init
3303    def test_debug_info(self):
3304        # only test keys in this test case. Values should be covered by
3305        # individual module debug info tests
3306        import torch.distributed.autograd as dist_autograd
3307
3308        info = _get_debug_info()
3309        rref_info = _rref_context_get_debug_info()
3310        agent_info = rpc.api._get_current_rpc_agent().get_debug_info()
3311        autograd_info = dist_autograd._get_debug_info()
3312        common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys()
3313        self.assertEqual(0, len(common_keys))
3314        expected = {}
3315        expected.update(rref_info)
3316        expected.update(agent_info)
3317        expected.update(autograd_info)
3318        # NB: key ordering is only guaranteed in Python 3.6+, so we manually
3319        # check that the two key sets are equal.
3320        for key in expected.keys():
3321            self.assertIn(key, info.keys())
3322
3323        for key in info.keys():
3324            self.assertIn(key, expected.keys())
3325
3326    @dist_init(setup_rpc=False)
3327    @skip_but_pass_in_sandcastle_if(
3328        IS_MACOS,
3329        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
3330    )
3331    def test_handle_send_exceptions(self):
3332        # test that if a callee node has gone down, we raise an appropriate
3333        # exception instead of just crashing.
3334        rpc.init_rpc(
3335            name="worker%d" % self.rank,
3336            backend=self.rpc_backend,
3337            rank=self.rank,
3338            world_size=self.world_size,
3339            rpc_backend_options=self.rpc_backend_options,
3340        )
3341        rpc._set_rpc_timeout(10)
3342        # This barrier is needed to ensure that some workers do not exit before
3343        # others have been brought up.
3344        initialize_pg(self.file_init_method, self.rank, self.world_size)
3345        dist.barrier()
3346        if self.rank == 1:
3347            dst_rank = (self.rank + 1) % self.world_size
3348            dst_worker = worker_name(dst_rank)
3349            # allow destination worker to exit without joining
3350            error_str = self.get_shutdown_error_regex()
3351            wait_until_node_failure(dst_rank, error_str)
3352            fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3))
3353            # The shutdown sequence is not well defined, so we may see any of
3354            # the error messages defined in get_shutdown_error_regex.
3355            with self.assertRaisesRegex(RuntimeError, error_str):
3356                fut.wait()
3357        # exit all workers non-gracefully.
3358        rpc.shutdown(graceful=False)
3359
3360    @dist_init
3361    def test_deadlock(self):
3362        # this test is copied from https://github.com/pytorch/pytorch/issues/45089
3363        if self.rank == 1:
3364            dst1 = worker_name((self.rank + 1) % self.world_size)
3365            x = torch.ones(2)
3366            y = torch.ones(2)
3367            rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait()
3368
3369        dist_initialized = dist.is_initialized()
3370        if not dist_initialized:
3371            dist.init_process_group(
3372                backend="gloo",
3373                init_method=self.file_init_method,
3374                rank=self.rank,
3375                world_size=self.world_size,
3376            )
3377
3378    @dist_init(setup_rpc=False)
3379    def test_local_shutdown_with_rpc(self):
3380        # test that we can start RPC, send RPCs, and then run local shutdown.
3381        rpc.init_rpc(
3382            name="worker%d" % self.rank,
3383            backend=self.rpc_backend,
3384            rank=self.rank,
3385            world_size=self.world_size,
3386            rpc_backend_options=self.rpc_backend_options,
3387        )
3388        n = self.rank + 1
3389        dst_rank = n % self.world_size
3390        rpc.rpc_sync(
3391            worker_name(dst_rank),
3392            torch.add,
3393            args=(torch.ones(n, n), torch.ones(n, n)),
3394        )
3395        # A barrier is needed to ensure that all RPCs are processed.
3396        # Otherwise, some RPCs can timeout since the receiving end
3397        # has terminated.
3398        initialize_pg(self.file_init_method, self.rank, self.world_size)
3399        dist.barrier()
3400        # pass in graceful=False to ensure that we don't wait for other workers.
3401        rpc.shutdown(graceful=False)
3402
3403    @dist_init(setup_rpc=False)
3404    def test_set_and_get_default_rpc_timeout(self):
3405        timeout = 0.5
3406
3407        # A new `RpcBackendOptions` is constructed
3408        # when accessing `self.rpc_backend_options`.
3409        rpc_backend_options = self.rpc_backend_options
3410        rpc_backend_options.rpc_timeout = timeout
3411
3412        rpc.init_rpc(
3413            name=worker_name(self.rank),
3414            backend=self.rpc_backend,
3415            rank=self.rank,
3416            world_size=self.world_size,
3417            rpc_backend_options=rpc_backend_options,
3418        )
3419        set_timeout = rpc.get_rpc_timeout()
3420        self.assertEqual(timeout, set_timeout)
3421        rpc.shutdown()
3422
3423    @dist_init
3424    def test_default_timeout_used(self):
3425        """
3426        Tests that if no timeout is passed into rpc_async and rpc_sync, then the
3427        default timeout is used.
3428        """
3429        dst_rank = (self.rank + 1) % self.world_size
3430        rpc._set_rpc_timeout(0.001)  # 1 ms
3431        # Futures should time out and complete with a timeout exception.
3432        futs = [
3433            rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=())
3434            for _ in range(10)
3435        ]
3436        expected_error = self.get_timeout_error_regex()
3437        for fut in futs:
3438            with self.assertRaisesRegex(RuntimeError, expected_error):
3439                fut.wait()
3440
3441        # Ensure that if a new timeout is set, old futures don't time out but new ones do.
3442        rpc._set_rpc_timeout(200)  # 200 seconds
3443        # create a longstanding RPC.
3444        fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
3445        # now, set a short timeout.
3446        rpc._set_rpc_timeout(0.001)
3447        # fut2 should time out, fut1 should not.
3448        fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
3449        with self.assertRaisesRegex(RuntimeError, expected_error):
3450            fut2.wait()
3451        fut1.wait()
3452
3453        # Zero timeout means infinity, so future should run to completion.
3454        rpc._set_rpc_timeout(0)
3455        rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait()
3456
3457        # reset to default timeout so shutdown messages can process cleanly.
3458        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
3459
3460    @dist_init
3461    def test_rpc_timeouts(self):
3462        # TODO: enable timeouts for rpc.remote/RRef (https://github.com/pytorch/pytorch/issues/33803)
3463        dst_rank = (self.rank + 1) % self.world_size
3464        dst_worker = worker_name(dst_rank)
3465        timeout = 0.1  # 100 ms
3466        expected_error = self.get_timeout_error_regex()
3467        # Test async UDF
3468        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
3469        with self.assertRaisesRegex(RuntimeError, expected_error):
3470            fut.wait()
3471
3472        # Ensure the RPC runs to completion when no timeout is passed and the
3473        # default RPC timeout is used.
3474        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait()
3475
3476        # Test sync UDF
3477        with self.assertRaisesRegex(RuntimeError, expected_error):
3478            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
3479
3480        # Ensure the RPC runs to completion when no timeout is passed and the
3481        # default RPC timeout is used.
3482        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
3483
3484        # If we set a default timeout for RPCs, it should be respected, though
3485        # still overridden if we pass in a different timeout to the APIs.
3486        rpc._set_rpc_timeout(0.001)
3487        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
3488        with self.assertRaisesRegex(RuntimeError, expected_error):
3489            fut.wait()
3490        with self.assertRaisesRegex(RuntimeError, expected_error):
3491            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
3492
3493        # The RPCs should run to completion since we override the timeout.
3494        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait()
3495        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5)
3496        # Passing in a zero timeout should ensure that the RPC won't time out.
3497        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait()
3498        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0)
3499        # Reset for clean shutdown
3500        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
3501
3502    def test_dist_init_decorator(self):
3503        @dist_init(setup_rpc=False)
3504        def test_func(self):
3505            return "expected result"
3506
3507        self.assertEqual(test_func(self), "expected result")
3508
3509        @dist_init
3510        def test_func(self):
3511            return "expected result"
3512
3513        self.assertEqual(test_func(self), "expected result")
3514
3515    def test_use_rpc_pickler(self):
3516        class TestPickler:
3517            pass
3518
3519        test_pickler = TestPickler()
3520        with _use_rpc_pickler(test_pickler):
3521            self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler)
3522        self.assertTrue(
3523            torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler
3524        )
3525
3526    @dist_init
3527    def test_wait_all(self):
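        # _wait_all() records every future created inside the block in
        # _thread_local_var.future_list, waits on all of them on exit, and
        # then drops the thread-local list, as verified below.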
3528        with _wait_all():
3529            self.assertTrue(_thread_local_var.future_list == [])
3530            dst = worker_name((self.rank + 1) % self.world_size)
3531            fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
3532            self.assertTrue(len(_thread_local_var.future_list) == 1)
3533            self.assertTrue(isinstance(_thread_local_var.future_list[0], torch._C.Future))
3534        self.assertTrue(fut.done())
3535        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
3536        self.assertFalse(hasattr(_thread_local_var, "future_list"))
3537
3538    @dist_init
3539    def test_wait_all_multiple_call(self):
3540        with _wait_all():
3541            self.assertTrue(_thread_local_var.future_list == [])
3542            dst = worker_name((self.rank + 1) % self.world_size)
3543            for i in range(20):
3544                fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1))
3545                res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1))
3546                self.assertEqual(res, torch.ones(i, i) + 1)
3547                self.assertEqual(fut.wait(), torch.ones(i, i) + 1)
3548            self.assertTrue(len(_thread_local_var.future_list) == 20)
3549        self.assertFalse(hasattr(_thread_local_var, "future_list"))
3550
3551    @dist_init
3552    def test_wait_all_timeout(self):
3553        expected_error = self.get_timeout_error_regex()
3554        with self.assertRaisesRegex(RuntimeError, expected_error):
3555            with _wait_all():
3556                self.assertTrue(_thread_local_var.future_list == [])
3557                dst = worker_name((self.rank + 1) % self.world_size)
3558                timeout = 0.1  # 100 ms
3559                fut = rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout)
3560        self.assertFalse(hasattr(_thread_local_var, "future_list"))
3561
3562    @dist_init
3563    def test_wait_all_raise_in_user_func(self):
3564        with self.assertRaises(ValueError):
3565            with _wait_all():
3566                self.assertTrue(_thread_local_var.future_list == [])
3567                dst = worker_name((self.rank + 1) % self.world_size)
3568                fut = rpc.rpc_async(dst, raise_func)
3569        self.assertFalse(hasattr(_thread_local_var, "future_list"))
3570
3571    @dist_init
3572    def test_wait_all_raise_in_body(self):
3573        with self.assertRaises(ValueError):
3574            with _wait_all():
3575                raise_func()
3576        self.assertFalse(hasattr(_thread_local_var, "future_list"))
3577
3578    @dist_init
3579    def test_custom_exception_throw_during_reconstruction(self):
3580        """
3581        Test that we still surface info about the remote-side exception even
3582        when we cannot recreate it on the client side.
3583        """
3584        initialize_pg(self.file_init_method, self.rank, self.world_size)
3585        if self.rank != 0:
3586            exc_caught = False
3587            dst = worker_name(0)
3588            try:
3589                rpc.rpc_sync(dst, custom_raise_func, args=())
3590            except RuntimeError as e:
3591                exc_caught = True
3592                msg = str(e)
3593                print(f"Got msg {msg}")
3594                self.assertTrue("Original exception on remote side was" in msg)
3595                self.assertTrue("CustomException" in msg)
3596            except BaseException as e:
3597                raise RuntimeError(
3598                    f"Failure - expected RuntimeError, got {e}"
3599                ) from e
3600            finally:
3601                self.assertTrue(exc_caught)
3602
3603        dist.barrier()
3604
3605
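    # Class-level Event shared by the wait_all_exit_early tests below: the
    # remote timed_out_rpc() blocks until the caller sets the event, letting
    # a test keep one RPC pending while the others fail.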
3606    timed_out_rpc_event = None
3607
3608    @staticmethod
3609    def timed_out_rpc():
3610        RpcTest.timed_out_rpc_event.wait()
3611
3612    @dist_init
3613    def test_wait_all_exit_early_python(self):
3614        # Initialize the event in the subprocess.
3615        RpcTest.timed_out_rpc_event = Event()
3616
3617        # Wait for all processes to initialize event.
3618        initialize_pg(self.file_init_method, self.rank, self.world_size)
3619        dist.barrier()
3620
3621        dst = worker_name((self.rank + 1) % self.world_size)
3622        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
3623        fut2 = rpc.rpc_async(dst, raise_func)
3624        fut3 = rpc.rpc_async(dst, raise_func)
3625
3626        # We should receive the error from fut2
3627        with self.assertRaisesRegex(ValueError, expected_err):
3628            torch.futures.wait_all([fut1, fut2, fut3])
3629
3630        # Unblock RPC thread for fut1
3631        RpcTest.timed_out_rpc_event.set()
3632
3633    @dist_init
3634    def test_wait_all_exit_early_builtin(self):
3635        # Initialize the event in the subprocess.
3636        RpcTest.timed_out_rpc_event = Event()
3637
3638        # Wait for all processes to initialize event.
3639        initialize_pg(self.file_init_method, self.rank, self.world_size)
3640        dist.barrier()
3641
3642        dst = worker_name((self.rank + 1) % self.world_size)
3643        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
3644        fut2 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
3645        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
3646
3647        # We should receive the error from fut2
3648        with self.assertRaisesRegex(RuntimeError, "size of tensor"):
3649            torch.futures.wait_all([fut1, fut2, fut3])
3650
3651        # Unblock RPC thread for fut1
3652        RpcTest.timed_out_rpc_event.set()
3653
3654    @dist_init
3655    def test_wait_all_exit_early_script_function(self):
3656        # Initialize the event in the subprocess.
3657        RpcTest.timed_out_rpc_event = Event()
3658
3659        # Wait for all processes to initialize event.
3660        initialize_pg(self.file_init_method, self.rank, self.world_size)
3661        dist.barrier()
3662
3663        dst = worker_name((self.rank + 1) % self.world_size)
3664        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
3665        fut2 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
3666        fut3 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
3667
3668        # We should receive the error from fut2
3669        with self.assertRaisesRegex(RuntimeError, expected_err):
3670            torch.futures.wait_all([fut1, fut2, fut3])
3671
3672        # Unblock RPC thread for fut1
3673        RpcTest.timed_out_rpc_event.set()
3674
3675
3676    @dist_init
3677    def test_function_not_on_callee(self):
3678        # Test that if a function does not exist on a callee we don't crash;
3679        # instead, an error indicating the function cannot be resolved is raised.
3680        this_module = sys.modules[__name__]
3681        caller_worker = "worker0"
3682        callee_worker = "worker1"
3683
3684        if self.rank == 1:
3685            # Use delattr to remove the binding of the func on this node.
3686            delattr(this_module, "foo_add")
3687            # notify remote end that we have removed it.
3688            rpc.rpc_sync(caller_worker, set_value, args=(self.rank,))
3689
3690        if self.rank == 0:
3691            # func exists on caller, but not callee.
3692            # wait for remote end to remove the binding of foo_add func.
3693            wait_for_value_future()
3694            # Ensure that we have the attribute on this module. Otherwise, the test could fail due to a caller-side pickling error.
3695            self.assertTrue(hasattr(this_module, "foo_add"))
3696            with self.assertRaisesRegex(
3697                RuntimeError, "RPC pickler does not serialize"
3698            ):
3699                rpc.rpc_sync(callee_worker, foo_add, args=())
3700
3701    @dist_init
3702    def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self):
3703        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
3704
3705        a = MyClass(1)
3706        b = MyClass(2)
3707
3708        # Create a reference cycle so that Python does not immediately garbage collect a and b.
3709        a.other = b
3710        b.other = a
3711
3712        n = self.rank
3713        a.rref = rpc.remote(
3714            dst_worker_name,
3715            torch.add,
3716            args=(torch.ones(n, n), 2)
3717        )
3718
3719    @dist_init(setup_rpc=False)
3720    def test_use_rref_after_shutdown(self):
3721        rpc.init_rpc(
3722            name="worker%d" % self.rank,
3723            backend=self.rpc_backend,
3724            rank=self.rank,
3725            world_size=self.world_size,
3726            rpc_backend_options=self.rpc_backend_options,
3727        )
3728        n = self.rank + 1
3729        dst_rank = n % self.world_size
3730        rref = rpc.remote(
3731            worker_name(dst_rank),
3732            torch.add,
3733            args=(torch.ones(n, n), torch.ones(n, n)),
3734        )
3735        # pass in graceful=True to ensure that local UserRRefs are deleted.
3736        rpc.shutdown(graceful=True)
3737
3738        with self.assertRaisesRegex(
3739            RuntimeError, "Cannot call to_here\\(\\) on it after deletion."
3740        ):
3741            rref.to_here()
3742
3743        with self.assertRaisesRegex(
3744            RuntimeError, "Cannot call fork an UserRRef after deletion."
3745        ):
3746            import torch.distributed.rpc.internal as internal
3747            internal.serialize(rref)
3748
3749    @staticmethod
3750    def _return_gpu_tensor():
3751        return torch.rand(3, 3).cuda(0)
3752
3753    @staticmethod
3754    def _return_gpu_tensor_list():
3755        return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)]
3756
3757    @staticmethod
3758    def _gpu_tensor_list_arg(tensor_list):
3759        return torch.rand(3, 3)
3760
3761    def _create_rref(self):
3762        owner_rank = (self.rank + 2) % self.world_size
3763        return rpc.remote(
3764            worker_name(owner_rank),
3765            torch.add,
3766            args=(torch.zeros(2, 2), 1)
3767        )
3768
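    # The owner is deliberately two ranks away so that, in the tests below,
    # the worker that receives the RRef as an RPC argument is not the RRef's
    # owner, exercising the user-RRef confirmation path.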
3769    @dist_init
3770    def test_user_rrefs_confirmed(self):
3771        dst_rank = (self.rank + 1) % self.world_size
3772        rref = self._create_rref()
3773        ret = rpc.rpc_sync(
3774            worker_name(dst_rank),
3775            check_rref_confirmed,
3776            args=(rref,)
3777        )
3778        self.assertEqual(ret, True)
3779
3780    @dist_init
3781    def test_user_rrefs_confirmed_remote(self):
3782        dst_rank = (self.rank + 1) % self.world_size
3783        rref = self._create_rref()
3784        ret_rref = rpc.remote(
3785            worker_name(dst_rank),
3786            check_rref_confirmed,
3787            args=(rref,)
3788        )
3789        self.assertEqual(ret_rref.to_here(), True)
3790
3791    @dist_init
3792    def test_rref_py_pickle_not_supported(self):
3793        local_rref = RRef(35)
3794        with TemporaryFileName() as fname:
3795            with self.assertRaisesRegex(RuntimeError, "Can not pickle rref in python pickler"):
3796                torch.save(local_rref, fname)
3797
3798    @dist_init
3799    def test_remote_throw(self):
3800        rref = rpc.remote(worker_name((self.rank + 1) % self.world_size),
3801                          raise_or_inc,
3802                          args=(torch.ones(2),))
3803        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
3804            rref.to_here()
3805
3806    @dist_init
3807    def test_non_cont_tensors(self):
3808        if self.rank == 0:
3809            # Create a non-contiguous tensor.
3810            t = torch.rand(5, 5)
3811            t_view = t.narrow(1, 2, 2)
3812            self.assertFalse(t_view.is_contiguous())
3813            t_cont = t_view.contiguous()
3814            self.assertTrue(t_cont.is_contiguous())
3815            self.assertEqual(t_view, t_cont)
3816
3817            # Send non-cont tensor over RPC.
3818            next_rank = (self.rank + 1) % self.world_size
3819            t_ret = rpc.rpc_sync(worker_name(next_rank), non_cont_test, args=(t_view, t_cont))
3820
3821            # Verify the returned tensor.
3822            self.assertEqual(t_view, t_ret)
3823            self.assertFalse(t_ret.is_contiguous())
3824
3825    @dist_init
3826    def test_callback_simple(self):
3827        set_by_cb = concurrent.futures.Future()
3828        n = self.rank + 1
3829
3830        def callback(fut):
3831            ret = fut.wait()
3832            self.assertEqual(ret, torch.ones(n, n) * 2)
3833            set_by_cb.set_result(ret.clone() + 1)
3834
3835        fut = rpc.rpc_async(
3836            worker_name(n % self.world_size),
3837            torch.add,
3838            args=(torch.ones(n, n), torch.ones(n, n))
3839        )
3840
3841        fut.then(callback)
3842
3843        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
3844        self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
3845        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
3846
3847    @dist_init
3848    def test_callback_wrong_arg_num(self):
3849        set_by_cb = concurrent.futures.Future()
3850        n = self.rank + 1
3851
3852        fut = rpc.rpc_async(
3853            worker_name(n % self.world_size),
3854            torch.add,
3855            args=(torch.ones(n, n), torch.ones(n, n))
3856        )
3857
3858        cb_fut = fut.then(my_function)
3859
3860        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
3861
3862        with self.assertRaisesRegex(
3863            RuntimeError,
3864            "my\\_function\\(\\) missing 2 required positional arguments"
3865        ):
3866            cb_fut.wait()
3867
3868    @dist_init
3869    def test_callback_wrong_arg_type(self):
3870        dst = worker_name((self.rank + 1) % self.world_size)
3871
3872        fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
3873        fut1 = fut0.then(lambda x: x + 1)
3874
3875        with self.assertRaisesRegex(
3876            RuntimeError,
3877            "unsupported operand type\\(s\\) for \\+"
3878        ):
3879            fut1.wait()
3880
3881    @dist_init
3882    def test_callback_multi(self):
3883        num_cbs = 10
3884        n = self.rank + 1
3885
3886        def callback(idx, fut):
3887            ret = fut.wait()
3888            self.assertEqual(ret, torch.ones(n, n) * 2)
3889            return ret + idx
3890
3891        fut = rpc.rpc_async(
3892            worker_name(n % self.world_size),
3893            torch.add,
3894            args=(torch.ones(n, n), torch.ones(n, n))
3895        )
3896
3897        cb_futs = []
3898        for idx in range(num_cbs):
3899            cb_futs.append(fut.then(partial(callback, idx)))
3900
3901        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
3902
3903        for idx in range(num_cbs):
3904            self.assertEqual(
3905                cb_futs[idx].wait(),
3906                torch.ones(n, n) * 2 + idx
3907            )
3908
3909        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
3910
3911    @dist_init
3912    def test_callback_chain(self):
3913        n = self.rank + 1
3914        dst = worker_name(n % self.world_size)
3915
3916        def callback(fut):
3917            return fut.wait() + 1
3918
3919        fut = rpc.rpc_async(
3920            worker_name(n % self.world_size),
3921            torch.add,
3922            args=(torch.ones(n, n), 1)
3923        )
3924
3925        num_cbs = 20
3926        for _ in range(num_cbs):
3927            fut = fut.then(callback)
3928
3929        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
3930
3931    @dist_init
3932    def test_callback_in_rpc(self):
3933        dst1 = worker_name((self.rank + 1) % self.world_size)
3934        dst2 = worker_name((self.rank + 2) % self.world_size)
3935
3936        ret = rpc.rpc_sync(
3937            dst1,
3938            add_use_future_cb,
3939            args=(dst2, torch.ones(2, 2), 1, 2)
3940        )
3941        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
3942
3943    @dist_init
3944    def test_callback_with_ret(self):
3945        dst = worker_name((self.rank + 1) % self.world_size)
3946
3947        def callback(fut0):
3948            fut2 = rpc.rpc_async(
3949                dst,
3950                torch.add,
3951                args=(fut0.wait(), 1)
3952            ).then(lambda fut1: fut1.wait() + 1)
3953
3954            return fut2.wait()
3955
3956        fut3 = rpc.rpc_async(
3957            dst,
3958            torch.add,
3959            args=(torch.ones(2, 2), 1)
3960        ).then(callback)
3961
3962        self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)
3963
3964    @dist_init
3965    def test_callback_with_error(self):
3966        dst = worker_name((self.rank + 1) % self.world_size)
3967
3968        def callback(fut0):
3969            with self.assertRaisesRegex(ValueError, "Expected error"):
3970                fut0.wait()
3971            raise RuntimeError("Another expected error")
3972
3973        fut1 = rpc.rpc_async(dst, raise_func).then(callback)
3974        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
3975            fut1.wait()
3976
3977    @dist_init
3978    def test_callback_none(self):
3979        dst = worker_name((self.rank + 1) % self.world_size)
3980        with self.assertRaisesRegex(
3981            TypeError,
3982            "incompatible function arguments."
3983        ):
3984            rpc.rpc_async(dst, raise_func).then(None)
3985
3986    @dist_init
3987    def test_add_done_callback(self):
3988        set_by_cb = False
3989        n = self.rank + 1
3990
3991        def callback(fut):
3992            nonlocal set_by_cb
3993            fut.wait()
3994            set_by_cb = True
3995
3996        fut = rpc.rpc_async(
3997            worker_name(n % self.world_size),
3998            torch.add,
3999            args=(torch.ones(n, n), torch.ones(n, n))
4000        )
4001
4002        fut.add_done_callback(callback)
4003        fut_then = fut.then(lambda _: True)
4004
4005        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
4006
4007        # There is no guarantee that the add_done_callback fn runs before the test finishes.
4008        # Wait on a 'then' callback added afterwards to guarantee the first callback has run.
4009        fut_then.wait()
4010        self.assertTrue(set_by_cb)
4011        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
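        # Note: fut.then(cb) returns a new Future wrapping cb's result,
        # whereas fut.add_done_callback(cb) returns nothing and offers no
        # completion-ordering guarantee, hence the fut_then.wait() above.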
4012
4013    @dist_init
4014    def test_mark_future_twice(self):
4015        fut = rpc.rpc_async(
4016            worker_name((self.rank + 1) % self.world_size),
4017            torch.add,
4018            args=(torch.zeros(2, 2), 1)
4019        )
4020        self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1)
4021        with self.assertRaisesRegex(
4022            RuntimeError,
4023            "Future can only be marked completed once"
4024        ):
4025            fut.set_result(1)
4026
4027    @dist_init
4028    def test_pickle_future(self):
4029        fut = torch.futures.Future()
4030        errMsg = "Can not pickle torch.futures.Future"
4031
4032        dst = worker_name((self.rank + 1) % self.world_size)
4033        with TemporaryFileName() as fname:
4034            with self.assertRaisesRegex(RuntimeError, errMsg):
4035                rpc.rpc_sync(dst, fail_on_fut, args=(fut,))
4036
4037        with TemporaryFileName() as fname:
4038            with self.assertRaisesRegex(RuntimeError, errMsg):
4039                rpc.rpc_async(dst, fail_on_fut, args=(fut,))
4040
4041        with TemporaryFileName() as fname:
4042            with self.assertRaisesRegex(RuntimeError, errMsg):
4043                rpc.remote(dst, fail_on_fut, args=(fut,))
4044
4045    @dist_init
4046    def test_future_done(self):
4047        dst = worker_name((self.rank + 1) % self.world_size)
4048        fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
4049        fut.wait()
4050        self.assertTrue(fut.done())
4051
4052    @dist_init
4053    def test_future_done_exception(self):
4054        dst = worker_name((self.rank + 1) % self.world_size)
4055        fut = rpc.rpc_async(dst, raise_func)
4056        with self.assertRaisesRegex(ValueError, "Expected error"):
4057            fut.wait()
4058        self.assertTrue(fut.done())
4059
4060    def _test_future_cb(self, func):
4061        dst1 = worker_name((self.rank + 1) % self.world_size)
4062        dst2 = worker_name((self.rank + 2) % self.world_size)
4063
4064        ret = rpc.rpc_sync(
4065            dst1,
4066            func,
4067            args=(dst2, torch.ones(2, 2), 1, 2)
4068        )
4069        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
4070
4071    @dist_init
4072    def test_future_in_rpc(self):
4073        self._test_future_cb(add_use_future_set_result)
4074
4075    @dist_init
4076    def test_future_nested_callback(self):
4077        self._test_future_cb(add_use_future_nested_cb)
4078
4079    def _test_async_function_raise(self, mode):
4080        with self.assertRaisesRegex(RuntimeError, "Expected error"):
4081            self._run_func_in_mode(
4082                worker_name((self.rank + 1) % self.world_size),
4083                async_raise_func,
4084                mode
4085            )
4086
4087    @dist_init
4088    def test_async_function_raise(self):
4089        self._test_async_function_raise(RPCExecMode.SYNC)
4090
4091    @dist_init
4092    def test_async_function_raise_async(self):
4093        self._test_async_function_raise(RPCExecMode.ASYNC)
4094
4095    @dist_init
4096    def test_async_function_raise_remote(self):
4097        self._test_async_function_raise(RPCExecMode.REMOTE)
4098
4099    def _test_async_function_wrong_return_type(self, mode):
4100        errMsg = (
4101            "Functions decorated with @rpc\\.async_function must return a "
4102            "torch\\.futures\\.Future object,"
4103        )
4104        with self.assertRaisesRegex(RuntimeError, errMsg):
4105            self._run_func_in_mode(
4106                worker_name((self.rank + 1) % self.world_size),
4107                async_wrong_type,
4108                mode
4109            )
4110
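    # For contrast, a well-formed async-execution UDF returns a Future, along
    # the lines of the module-level helpers used below, e.g.:
    #
    #   @rpc.functions.async_execution
    #   def async_add(to, x, y):
    #       return rpc.rpc_async(to, torch.add, args=(x, y))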
4111    @dist_init
4112    def test_async_function_wrong_return_type(self):
4113        self._test_async_function_wrong_return_type(RPCExecMode.SYNC)
4114
4115    @dist_init
4116    def test_async_function_wrong_return_type_async(self):
4117        self._test_async_function_wrong_return_type(RPCExecMode.ASYNC)
4118
4119    @dist_init
4120    def test_async_function_wrong_return_type_remote(self):
4121        self._test_async_function_wrong_return_type(RPCExecMode.REMOTE)
4122
4123    @dist_init
4124    def test_async_function_simple(self):
4125        dst1 = worker_name((self.rank + 1) % self.world_size)
4126        dst2 = worker_name((self.rank + 2) % self.world_size)
4127
4128        ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1))
4129        self.assertEqual(ret, torch.ones(2, 2) + 1)
4130
4131    def _test_async_function(self, fn, mode=RPCExecMode.SYNC):
4132        dst1 = worker_name((self.rank + 1) % self.world_size)
4133        dst2 = worker_name((self.rank + 2) % self.world_size)
4134
4135        args = (dst2, torch.ones(2, 2), 1, 2)
4136        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
4137        self.assertEqual(ret, torch.ones(2, 2) + 3)
4138
4139    @dist_init
4140    def test_async_function_with_future_ctor(self):
4141        self._test_async_function(async_add_with_future_ctor)
4142
4143    @dist_init
4144    def test_async_function_with_future_ctor_remote(self):
4145        self._test_async_function(
4146            async_add_with_future_ctor,
4147            RPCExecMode.REMOTE
4148        )
4149
4150    @dist_init
4151    def test_async_function_chained(self):
4152        self._test_async_function(async_add_chained)
4153
4154    @dist_init
4155    def test_async_function_chained_remote(self):
4156        self._test_async_function(async_add_chained, RPCExecMode.REMOTE)
4157
4158    @dist_init
4159    def test_async_function_nested(self):
4160        self._test_async_function(async_add_nested)
4161
4162    @dist_init
4163    def test_async_function_nested_remote(self):
4164        self._test_async_function(async_add_nested, RPCExecMode.REMOTE)
4165
4166    @dist_init
4167    def test_async_static_method(self):
4168        self._test_async_function(AsyncExecutionClass.static_async_add)
4169
4170    @dist_init
4171    def test_async_static_method_remote(self):
4172        self._test_async_function(
4173            AsyncExecutionClass.static_async_add,
4174            RPCExecMode.REMOTE
4175        )
4176
4177    @dist_init
4178    def test_async_class_method(self):
4179        self._test_async_function(AsyncExecutionClass.class_async_add)
4180
4181    @dist_init
4182    def test_async_class_method_remote(self):
4183        self._test_async_function(
4184            AsyncExecutionClass.class_async_add,
4185            RPCExecMode.REMOTE
4186        )
4187
4188    def _test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC):
4189        dst1 = worker_name((self.rank + 1) % self.world_size)
4190        dst2 = worker_name((self.rank + 2) % self.world_size)
4191        rref = rpc.remote(dst1, AsyncExecutionClass)
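        # rref.rpc_sync(), rref.rpc_async(), and rref.remote() return proxies
        # that forward the method call to the object on the owner via the
        # corresponding RPC API.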
4192
4193        x = torch.ones(2, 2)
4194        y = torch.ones(2, 2) + 1
4195        if mode == RPCExecMode.SYNC:
4196            ret = rref.rpc_sync().static_async_add(dst2, x, x, y)
4197            ret += rref.rpc_sync().class_async_add(dst2, x, x, y)
4198            ret += rref.rpc_sync().bound_async_add(dst2, x, x, y)
4199        elif mode == RPCExecMode.ASYNC:
4200            ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait()
4201            ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait()
4202            ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait()
4203        elif mode == RPCExecMode.REMOTE:
4204            ret = rref.remote().static_async_add(dst2, x, x, y).to_here()
4205            ret += rref.remote().class_async_add(dst2, x, x, y).to_here()
4206            ret += rref.remote().bound_async_add(dst2, x, x, y).to_here()
4207
4208        self.assertEqual(ret, 3 * 4 * x)
4209
4210    @dist_init
4211    def test_async_class_rref_proxy(self):
4212        self._test_async_class_rref_proxy()
4213
4214    @dist_init
4215    def test_async_class_rref_proxy_async(self):
4216        self._test_async_class_rref_proxy(mode=RPCExecMode.ASYNC)
4217
4218    @dist_init
4219    def test_async_class_rref_proxy_remote(self):
4220        self._test_async_class_rref_proxy(mode=RPCExecMode.REMOTE)
4221
4222    def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC):
4223        dst1 = worker_name((self.rank + 1) % self.world_size)
4224        dst2 = worker_name((self.rank + 2) % self.world_size)
4225
4226        num = 20
4227        step = 3
4228        args = (dst2, torch.ones(2, 2), num, step)
4229        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
4230        self.assertEqual(ret, torch.ones(2, 2) + num * step)
4231
4232    @dist_init
4233    def test_async_function_multi_chained(self):
4234        self._test_async_function_multi(async_add_chained_multi)
4235
4236    @dist_init
4237    def test_async_function_multi_chained_async(self):
4238        self._test_async_function_multi(
4239            async_add_chained_multi,
4240            RPCExecMode.ASYNC
4241        )
4242
4243    @dist_init
4244    def test_async_function_multi_chained_remote(self):
4245        self._test_async_function_multi(
4246            async_add_chained_multi,
4247            RPCExecMode.REMOTE
4248        )
4249
4250    @dist_init
4251    def test_async_function_multi_fanout(self):
4252        self._test_async_function_multi(async_add_multi_fanout)
4253
4254    @dist_init
4255    def test_async_function_multi_fanout_async(self):
4256        self._test_async_function_multi(
4257            async_add_multi_fanout,
4258            RPCExecMode.ASYNC
4259        )
4260
4261    @dist_init
4262    def test_async_function_multi_fanout_remote(self):
4263        self._test_async_function_multi(
4264            async_add_multi_fanout,
4265            RPCExecMode.REMOTE
4266        )
4267
4268    def _test_return_future(self, mode):
4269        with self.assertRaisesRegex(
4270            RuntimeError,
4271            "Can not pickle torch.futures.Future"
4272        ):
4273            self._run_func_in_mode(
4274                worker_name((self.rank + 1) % self.world_size),
4275                return_future,
4276                mode
4277            )
4278
4279    @dist_init
4280    def test_return_future(self):
4281        self._test_return_future(RPCExecMode.SYNC)
4282
4283    @dist_init
4284    def test_return_future_async(self):
4285        self._test_return_future(RPCExecMode.ASYNC)
4286
4287    @dist_init
4288    def test_return_future_remote(self):
4289        self._test_return_future(RPCExecMode.REMOTE)
4290
4291    @dist_init
4292    def test_rref_timeout(self):
4293        # This test is similar to ones in FaultyProcessGroupTest, but is meant to be
4294        # run with other backends besides ProcessGroup.
4295        if self.rank != 0:
4296            return
4297
4298        dst_rank = (self.rank + 1) % self.world_size
4299        dst_worker = f"worker{dst_rank}"
4300        # 10 ms timeout
4301        rref = rpc.remote(dst_worker, my_sleep_func, args=(2, ), timeout=0.01)
4302        # Future corresponding to the remote creation should time out.
4303        expected_error = self.get_timeout_error_regex()
4304        with self.assertRaisesRegex(RuntimeError, expected_error):
4305            rref._get_future().wait()
4306        # Call to ensure pending callbacks are run.
4307        wait_until_pending_futures_and_users_flushed()
4308        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
4309            rref.to_here()
4310
4311        wait_until_owners_and_forks_on_rank(1, 1, rank=1)
4312
4313    @dist_init(setup_rpc=False)
4314    @skip_but_pass_in_sandcastle_if(
4315        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
4316        "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614."
4317    )
4318    def test_init_pg_then_rpc(self):
4319        dist.init_process_group(
4320            backend="gloo",
4321            init_method=self.init_method,
4322            rank=self.rank,
4323            world_size=self.world_size,
4324        )
4325
4326        rpc.init_rpc(
4327            name=worker_name(self.rank),
4328            backend=self.rpc_backend,
4329            rank=self.rank,
4330            world_size=self.world_size,
4331            rpc_backend_options=self.rpc_backend_options,
4332        )
4333
4334        # Test RPC.
4335        next_rank = (self.rank + 1) % self.world_size
4336        ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1))
4337        self.assertEqual(ret, torch.ones(2, 2) + 1)
4338
4339        # Test PG
4340        dist.barrier()
4341
4342        rpc.shutdown()
4343
4344    @dist_init(setup_rpc=False)
4345    @skip_but_pass_in_sandcastle_if(
4346        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
4347        "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614."
4348    )
4349    def test_init_rpc_then_pg(self):
4350        rpc.init_rpc(
4351            name=worker_name(self.rank),
4352            backend=self.rpc_backend,
4353            rank=self.rank,
4354            world_size=self.world_size,
4355            rpc_backend_options=self.rpc_backend_options,
4356        )
4357
4358        dist.init_process_group(
4359            backend="gloo",
4360            init_method=self.init_method,
4361            rank=self.rank,
4362            world_size=self.world_size,
4363        )
4364
4365        # Test RPC.
4366        next_rank = (self.rank + 1) % self.world_size
4367        ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1))
4368        self.assertEqual(ret, torch.ones(2, 2) + 1)
4369
4370        # Test PG
4371        dist.barrier()
4372
4373        rpc.shutdown()
4374
4375    @dist_init
4376    def test_wait_all_with_exception(self):
4377        futs = []
4378        dst = worker_name((self.rank + 1) % self.world_size)
4379        for _ in range(10):
4380            futs.append(rpc.rpc_async(dst, raise_func))
4381
4382        with self.assertRaisesRegex(ValueError, "Expected error"):
4383            ret = torch.futures.wait_all(futs)
4384
4385    @dist_init
4386    def test_wait_all_with_partial_exception(self):
4387        futs = []
4388        dst = worker_name((self.rank + 1) % self.world_size)
4389        for _ in range(10):
4390            futs.append(rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1)))
4391
4392        futs.append(rpc.rpc_async(dst, raise_func))
4393
4394        with self.assertRaisesRegex(ValueError, "Expected error"):
4395            ret = torch.futures.wait_all(futs)
4396
4397    @dist_init(setup_rpc=False)
4398    @skip_but_pass_in_sandcastle_if(
4399        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
4400        "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491",
4401    )
4402    def test_init_rpc_twice(self):
4403        initialize_pg(self.file_init_method, self.rank, self.world_size)
4404
4405        rpc.init_rpc(
4406            name=worker_name(self.rank),
4407            backend=self.rpc_backend,
4408            rank=self.rank,
4409            world_size=self.world_size,
4410            rpc_backend_options=self.rpc_backend_options,
4411        )
4412        rpc.shutdown()
4413
4414        # Wait for all workers to finish the first init/shutdown cycle.
4415        dist.barrier()
4416
4417        # Use a different rendezvous file for the next initialization (mutates the shared options in place).
4418        new_backend_options = self.rpc_backend_options
4419        new_backend_options.init_method += "init_2"
4420
4421        # Ensure rpc initialization works again.
4422        rpc.init_rpc(
4423            name=worker_name(self.rank),
4424            backend=self.rpc_backend,
4425            rank=self.rank,
4426            world_size=self.world_size,
4427            rpc_backend_options=new_backend_options,
4428        )
4429
4430        # Verify RPCs work after re-init.
4431        dst = worker_name((self.rank + 1) % self.world_size)
4432        rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
4433        rpc.rpc_sync(dst, foo_add, args=())
4434
4435        rpc.shutdown()
4436
4437    def test_wrong_types(self):
4438        with self.assertRaisesRegex(
4439            TypeError,
4440            "Argument backend must be a member of BackendType",
4441        ):
4442            rpc.init_rpc(
4443                name=worker_name(self.rank),
4444                rank=self.rank,
4445                world_size=self.world_size,
4446                backend="TENSORPIPE",
4447            )
4448
4449        with self.assertRaisesRegex(
4450            TypeError,
4451            "Argument rpc_backend_options must be an instance of RpcBackendOptions",
4452        ):
4453            rpc.init_rpc(
4454                name=worker_name(self.rank),
4455                rank=self.rank,
4456                world_size=self.world_size,
4457                backend=self.rpc_backend,
4458                rpc_backend_options={"init_method": self.init_method}
4459            )
4460
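    # For contrast, a well-formed call would look like this (a sketch,
    # assuming the TensorPipe backend):
    #
    #   rpc.init_rpc(
    #       name=worker_name(self.rank),
    #       rank=self.rank,
    #       world_size=self.world_size,
    #       backend=rpc.BackendType.TENSORPIPE,
    #       rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
    #           init_method=self.init_method,
    #       ),
    #   )
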
4461    def test_cannot_infer_backend_from_options(self):
4462        # An exception should be raised if the backend isn't specified and
4463        # the given options are not an instance of any known agent's
4464        # options class.
4465        rpc_backend_options = FooBackendOptions(self.init_method)
4466
4467        with self.assertRaisesRegex(TypeError, "Could not infer backend for options"):
4468            rpc.init_rpc(
4469                name=worker_name(self.rank),
4470                rank=self.rank,
4471                world_size=self.world_size,
4472                # Do _not_ pass backend.
4473                rpc_backend_options=rpc_backend_options,
4474            )
4475
4476    @dist_init
4477    def test_owner_rref_backward(self):
4478        dst = worker_name((self.rank + 1) % self.world_size)
4479        t1 = torch.rand(10, 10, requires_grad=True)
4480        rref = rpc.RRef(t1.sum() + t1.sum())
4481        rref.backward()
4482        expected_grad = torch.ones_like(t1) * 2
4483        self.assertEqual(expected_grad, t1.grad)
4484
4485        with dist_autograd.context() as context_id:
4486            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
4487            rref = rpc.RRef(t2.sum())
4488            rref.backward(context_id)
4489            self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1])
4490
4491        # Double backward.
4492        with dist_autograd.context() as context_id:
4493            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
4494            rref = rpc.RRef(t2.sum())
4495            rref.backward(context_id, retain_graph=True)
4496            rref.backward(context_id)
4497            self.assertEqual(expected_grad * 2, dist_autograd.get_gradients(context_id)[t1])
4498
4499        # Test errors.
4500        with self.assertRaisesRegex(RuntimeError, "tensors does not require grad and does not have a grad_fn"):
4501            rpc.RRef(torch.rand(10)).backward()
4502
4503        with self.assertRaisesRegex(RuntimeError, "grad can be implicitly created only for scalar outputs"):
4504            rpc.RRef(torch.rand(10, requires_grad=True)).backward()
4505
4506        with self.assertRaisesRegex(RuntimeError, "Could not find autograd context with id: 100"):
4507            rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100)
4508
4509        with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"):
4510            rpc.RRef("foo").backward()
4511
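    # Summary of the backward flavors exercised above and below:
    #
    #   rpc.RRef(loss).backward()     # owner RRef holding a local scalar
    #   rref.backward(context_id)     # inside a dist_autograd context
    #
    # User RRefs (created via rpc.remote) always require the dist_autograd
    # context id.
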
4512    @staticmethod
4513    def _sum(x):
4514        return x.sum()
4515
4516    @staticmethod
4517    def _identity(x):
4518        return x
4519
4520    @dist_init
4521    def test_user_rref_backward(self):
4522        dst = worker_name((self.rank + 1) % self.world_size)
4523        t = torch.rand(10, requires_grad=True)
4524        with dist_autograd.context() as context_id:
4525            rref = rpc.remote(dst, RpcTest._sum, args=(t,))
4526            rref.backward(context_id, retain_graph=True)
4527            rref.backward(context_id)
4528            self.assertEqual(torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t])
4529
4530        with dist_autograd.context() as context_id:
4531            rref = rpc.remote(dst, RpcTest._identity, args=("foo",))
4532            with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"):
4533                rref.backward(context_id)
4534
4535            with self.assertRaisesRegex(RuntimeError, "User RRefs require 'dist_autograd_ctx_id' to be specified"):
4536                rref.backward()
4537
4538    @dist_init(setup_rpc=False)
4539    def test_shutdown_errors(self):
4540        initialize_pg(self.file_init_method, self.rank, self.world_size)
4541
4542        rpc.init_rpc(
4543            name=worker_name(self.rank),
4544            backend=self.rpc_backend,
4545            rank=self.rank,
4546            world_size=self.world_size,
4547            rpc_backend_options=self.rpc_backend_options,
4548        )
4549
4550        if self.rank != 0:
4551            og_func = rpc.api._broadcast_to_followers
4552            og_rref_func = rpc.api._delete_all_user_and_unforked_owner_rrefs
4553
4554            # Monkey-patch _broadcast_to_followers to fail, which ensures
4555            # _all_gather on the leader raises an exception.
4556            def raise_error(sequence_id, objects_map):
4557                og_func(sequence_id, objects_map)
4558                raise RuntimeError('simulation')
4559
4560            # Monkey-patch _delete_all_user_and_unforked_owner_rrefs to fail,
4561            # which ensures the barrier is not called on the followers.
4562            def rref_error():
4563                raise RuntimeError('simulation rref')
4564
4565            try:
4566                rpc.api._broadcast_to_followers = raise_error
4567                rpc.api._delete_all_user_and_unforked_owner_rrefs = rref_error
4568                with self.assertRaisesRegex(RuntimeError, 'simulation rref'):
4569                    rpc.shutdown()
4570            finally:
4571                rpc.api._broadcast_to_followers = og_func
4572                rpc.api._delete_all_user_and_unforked_owner_rrefs = og_rref_func
4573        else:
4574            with self.assertRaisesRegex(RuntimeError, 'timed out in _all_gather'):
4575                rpc.shutdown()
4576
4577        dist.barrier()
4578
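    # A minimal sketch of the try/finally monkey-patch pattern used above,
    # assuming a hypothetical module `mod` with attribute `fn`:
    #
    #   original = mod.fn
    #   def failing(*args, **kwargs):
    #       raise RuntimeError("simulation")
    #   try:
    #       mod.fn = failing
    #       ...  # exercise the failure path
    #   finally:
    #       mod.fn = original  # restore even if an assertion fails
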
4579    @dist_init
4580    def test_my_parameter_server(self):
4581        self._my_parameter_server(False)
4582
4583
4584class CudaRpcTest(RpcAgentTestFixture):
4585
4586    @skip_if_lt_x_gpu(2)
4587    @dist_init
4588    def test_profiler_remote_cuda(self):
4589        if self.rank != 1:
4590            return
4591
4592        dst_cuda_0 = (self.rank + 1) % self.world_size
4593        dst_cuda_1 = (self.rank + 2) % self.world_size
4594        dst_worker_cuda_0 = worker_name(dst_cuda_0)
4595        dst_worker_cuda_1 = worker_name(dst_cuda_1)
4596
4597        with _profile(use_cuda=True) as p:
4598            fut1 = rpc.rpc_async(dst_worker_cuda_0, udf_with_torch_ops, args=(0, ))
4599            fut2 = rpc.rpc_async(dst_worker_cuda_1, udf_with_torch_ops, args=(1, ))
4600            fut1.wait()
4601            fut2.wait()
4602
4603        def get_name(event):
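            # Remote event names embed the original op name after the
            # "#remote_op: " marker; strip everything up to and including it.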
4604            return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR):]
4605
4606        function_events = p.function_events
4607        for event in function_events:
4608            if event.is_async:
4609                self.assertEqual(0, event.device_time_total)
4610                self.assertEqual([], event.kernels)
4611                self.assertEqual(0, event.device_time)
4612            else:
4613                if event.node_id == 1:
4614                    continue
4615                self.assertIn(event.node_id, [dst_cuda_0, dst_cuda_1])
4616                if get_name(event) in EXPECTED_REMOTE_EVENTS:
4617                    self.assertGreater(event.device_time_total, 0)
4618                    self.assertEqual(1, len(event.kernels))
4619                    kernel = event.kernels[0]
4620                    if event.node_id == dst_cuda_0:
4621                        self.assertEqual(kernel.device, 0)
4622                    if event.node_id == dst_cuda_1:
4623                        self.assertEqual(kernel.device, 1)
4624                    self.assertGreater(event.device_time, 0)
4625
4626        # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled
4627        # events.
4628        remote_events = [event for event in function_events if event.is_remote]
4629        remote_event_names = [get_name(event) for event in remote_events if get_name(event) in EXPECTED_REMOTE_EVENTS]
4630        self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS))
4631
4632
4633class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon):
4634
4635    def test_mismatched_type_for_options(self):
4636        # An exception should be raised if the options are not an instance of
4637        # TensorPipeRpcBackendOptions.
4638        rpc_backend_options = FooBackendOptions(self.init_method)
4639
4640        with self.assertRaisesRegex(
4641            TypeError, "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`"
4642        ):
4643            rpc.init_rpc(
4644                name=worker_name(self.rank),
4645                rank=self.rank,
4646                world_size=self.world_size,
4647                backend=rpc.BackendType.TENSORPIPE,
4648                rpc_backend_options=rpc_backend_options,
4649            )
4650
4651    def test_infer_backend_from_options(self):
4652        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
4653            init_method=self.init_method,
4654            _transports=tp_transports()
4655        )
4656
4657        rpc.init_rpc(
4658            name=worker_name(self.rank),
4659            rank=self.rank,
4660            world_size=self.world_size,
4661            # Do _not_ pass backend.
4662            rpc_backend_options=rpc_backend_options,
4663        )
4664
4665        self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent)
4666
4667    # FIXME Merge this test with the corresponding one in RpcTest.
4668    @dist_init(setup_rpc=False)
4669    def test_set_and_get_num_worker_threads(self):
4670        NUM_THREADS = 27
4671        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
4672            init_method=self.rpc_backend_options.init_method,
4673            num_worker_threads=NUM_THREADS,
4674            _transports=tp_transports(),
4675        )
4676        rpc.init_rpc(
4677            name=worker_name(self.rank),
4678            backend=self.rpc_backend,
4679            rank=self.rank,
4680            world_size=self.world_size,
4681            rpc_backend_options=rpc_backend_options,
4682        )
4683
4684        info = rpc.api._get_current_rpc_agent().get_debug_info()
4685        self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS)
4686        rpc.shutdown()
4687
4688    # FIXME Merge this test with the corresponding one in RpcTest.
4689    @dist_init(setup_rpc=False)
4690    def test_tensorpipe_set_default_timeout(self):
4691        # Set a high timeout since it doesn't affect test runtime and ensures
4692        # the test doesn't erroneously time out on slow machines.
4693        timeout = 100
4694        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
4695            init_method=self.rpc_backend_options.init_method,
4696            num_worker_threads=self.rpc_backend_options.num_worker_threads,
4697            rpc_timeout=timeout,
4698            _transports=tp_transports(),
4699        )
4700        rpc.init_rpc(
4701            name=worker_name(self.rank),
4702            backend=self.rpc_backend,
4703            rank=self.rank,
4704            world_size=self.world_size,
4705            rpc_backend_options=rpc_backend_options,
4706        )
4707
4708        default_timeout = rpc.get_rpc_timeout()
4709        self.assertEqual(default_timeout, timeout)
4710        rpc.shutdown()
4711
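    # The default set via `rpc_timeout` above can be overridden per call; a
    # minimal sketch, assuming an initialized agent and a destination `dst`:
    #
    #   fut = rpc.rpc_async(dst, torch.add, args=(torch.ones(1), 1), timeout=1)
    #   fut.wait()
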
4712    # FIXME Merge this test with the corresponding one in RpcTest.
4713    @dist_init(setup_rpc=False)
4714    def test_tensorpipe_options_throw_on_timedelta_timeout(self):
4715        from datetime import timedelta
4716
4717        timeout = timedelta()
4718        # Ensure that constructing TensorPipeRpcBackendOptions with timedelta fails
4719        with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"):
4720            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
4721                init_method=self.rpc_backend_options.init_method,
4722                num_worker_threads=self.rpc_backend_options.num_worker_threads,
4723                rpc_timeout=timeout,
4724            )
4725
4726    @dist_init
4727    def _test_rref_get_type_timeout(self, blocking):
4728        # Test where we try to get the type of an RRef from its owner, but RRef
4729        # creation takes longer than the timeout passed to _get_type.
4730        dst_rank = (self.rank + 1) % self.world_size
4731        dst = worker_name(dst_rank)
4732        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
4733        timeout = 0.5
4734        expected_err = self.get_timeout_error_regex()
4735        # Blocking: blocks on inline call
4736        if blocking:
4737            with self.assertRaisesRegex(RuntimeError, expected_err):
4738                slow_rref._get_type(timeout=timeout, blocking=blocking)
4739        # Non-blocking: blocks on wait
4740        else:
4741            fut = slow_rref._get_type(timeout=timeout, blocking=blocking)
4742            with self.assertRaisesRegex(RuntimeError, expected_err):
4743                fut.wait()
4744
4745        # FIXME We wait until the remote side has finished creating the OwnerRRef
4746        # because there's currently a race if we shut down RPC before that.
4747        slow_rref.to_here()
4748
4749    def test_rref_get_type_timeout_blocking(self):
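    # The two call styles exercised above, sketched for reference (private
    # API, illustration only):
    #
    #   rref._get_type(timeout=0.5, blocking=True)         # raises inline
    #   fut = rref._get_type(timeout=0.5, blocking=False)  # raises from wait()
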
4750        self._test_rref_get_type_timeout(blocking=True)
4751
4752    def test_rref_get_type_timeout_non_blocking(self):
4753        self._test_rref_get_type_timeout(blocking=False)
4754
4755    @dist_init
4756    def test_op_with_invalid_args(self):
4757        dst = worker_name((self.rank + 1) % self.world_size)
4758        with self.assertRaisesRegex(
4759            RuntimeError, "Overloaded torch operator invoked from Python failed to match any schema"
4760        ):
4761            rpc.rpc_sync(dst, torch.add, args=())
4762
4763    def _test_rref_proxy_timeout(self, rref_proxy_api):
4764        dst_rank = (self.rank + 1) % self.world_size
4765        dst = worker_name(dst_rank)
4766        rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), ))
4767        # Ensure RRef is created on remote node.
4768        rref.to_here()
4769        rref_api = getattr(rref, rref_proxy_api)
4770        self.assertIsNotNone(rref_api, f"Failed to get RRef proxy API: {rref_proxy_api}")
4771        expected_error = self.get_timeout_error_regex()
4772        timeout = 2
4773        with self.assertRaisesRegex(RuntimeError, expected_error):
4774            result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2))
4775            if rref_api == rref.rpc_async:
4776                result.wait()
4777            elif rref_api == rref.remote:
4778                result._get_future().wait()
4779
4780        # Case where rpc.remote() is stuck and exceeds timeout
4781        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
4782        timeout = 0.01
4783        rref_api = getattr(slow_rref, rref_proxy_api)
4784        # Note that even when we call rref.rpc_async() in this case, we
4785        # time out during future creation, not while waiting on the future.
4786        # This is because the rref proxy function calls rref._get_type before
4787        # returning the future, and that call blocks until the RRef is created
4788        # on the owner node or the specified timeout elapses.
4789        with self.assertRaisesRegex(RuntimeError, expected_error):
4790            result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2))
4791            # rpc_async returns immediately and surfaces a timeout through wait()
4792            if rref_api == slow_rref.rpc_async:
4793                result.wait()
4794
4795        # FIXME We wait until the remote side has finished creating the OwnerRRef
4796        # because there's currently a race if we shut down RPC before that.
4797        slow_rref.to_here()
4798
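    # The three RRef proxy APIs exercised above, sketched for reference; each
    # forwards the call to the object held by the RRef on its owner, and
    # `some_method` stands in for any method on that object:
    #
    #   rref.rpc_sync(timeout=2).some_method(x)          # blocking
    #   rref.rpc_async(timeout=2).some_method(x).wait()  # future-based
    #   rref.remote(timeout=2).some_method(x).to_here()  # returns an RRef
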
4799    @dist_init
4800    def test_rref_proxy_timeout(self):
4801        for rpc_api in ["rpc_sync", "rpc_async", "remote"]:
4802            self._test_rref_proxy_timeout(rpc_api)
4803
4804    @dist_init
4805    def test_send_to_rank_sparse(self):
4806        dst_rank = (self.rank + 1) % self.world_size
4807
4808        # Test sparse tensor
4809        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
4810            x = build_sparse_tensor()
4811            y = build_sparse_tensor()
4812            expected_tensor = (x + y)
4813            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
4814            self.assertEqual(expected_tensor, ret)
4815
4816        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
4817            x = build_sparse_tensor(coalesce=True)
4818            y = build_sparse_tensor(coalesce=True)
4819            expected_tensor = (x + y)
4820            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
4821            self.assertEqual(expected_tensor, ret)
4822
4823    @dist_init
4824    def test_self_py_udf_remote_sparse(self):
4825        self._self_py_udf_remote(
4826            rpc.get_worker_info(),
4827            build_sparse_tensor(),
4828            build_sparse_tensor(),
4829            build_sparse_tensor()
4830        )
4831
4832    @dist_init
4833    def test_self_remote_rref_as_rpc_arg_sparse(self):
4834        dst = worker_name((self.rank + 1) % self.world_size)
4835        self._self_remote_rref_as_rpc_arg(
4836            dst,
4837            build_sparse_tensor(),
4838            build_sparse_tensor(),
4839            build_sparse_tensor()
4840        )
4841
4842    @dist_init
4843    def test_self_remote_rref_as_self_rpc_arg_sparse(self):
4844        self._self_remote_rref_as_rpc_arg(
4845            rpc.get_worker_info(),
4846            build_sparse_tensor(),
4847            build_sparse_tensor(),
4848            build_sparse_tensor()
4849        )
4850
4851    @dist_init
4852    def test_self_remote_rref_as_remote_arg_sparse(self):
4853        dst = worker_name((self.rank + 1) % self.world_size)
4854        self._self_remote_rref_as_remote_arg(
4855            dst,
4856            build_sparse_tensor(),
4857            build_sparse_tensor(),
4858            build_sparse_tensor()
4859        )
4860
4861    @dist_init
4862    def test_self_remote_rref_as_self_remote_arg_sparse(self):
4863        self._self_remote_rref_as_remote_arg(
4864            rpc.get_worker_info(),
4865            build_sparse_tensor(),
4866            build_sparse_tensor(),
4867            build_sparse_tensor()
4868        )
4869
4870    def test_world_size_one_sparse(self):
4871        self._world_size_one(
4872            build_sparse_tensor(),
4873            build_sparse_tensor()
4874        )
4875
4876    @dist_init
4877    def test_multi_rpc_sparse(self):
4878        self._multi_rpc(True)
4879
4880    def test_wait_all_workers_sparse(self):
4881        self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor())
4882
4883    def test_wait_all_workers_twice_sparse(self):
4884        self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor())
4885
4886    @dist_init
4887    def test_py_sparse_tensors_in_container(self):
4888        n = self.rank + 1
4889        dst_rank = n % self.world_size
4890        a = [build_sparse_tensor(), build_sparse_tensor()]
4891        ret = rpc.rpc_sync(
4892            worker_name(dst_rank), my_container_sum, args=(a,)
4893        )
4894        self.assertEqual(ret, my_container_sum(a))
4895
4896    @dist_init
4897    def test_nested_rpc_sparse(self):
4898        self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2)
4899
4900    @dist_init
4901    def test_stress_heavy_rpc_sparse(self):
4902        self._stress_test_rpc(heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),))
4903
4904    @dist_init
4905    def test_builtin_remote_ret_sparse(self):
4906        self._builtin_remote_ret(
4907            build_sparse_tensor(),
4908            build_sparse_tensor(),
4909            build_sparse_tensor() * 2
4910        )
4911
4912    @dist_init
4913    def test_builtin_remote_self_sparse(self):
4914        self._builtin_remote_self(
4915            build_sparse_tensor(),
4916            build_sparse_tensor(),
4917            build_sparse_tensor() * 2
4918        )
4919
4920    @dist_init
4921    def test_multi_builtin_remote_ret_sparse(self):
4922        self._test_multi_remote_call(
4923            torch.add, True,
4924            args_fn=RpcTest._multi_args_fn
4925        )
4926
4927    @dist_init
4928    def test_multi_py_udf_remote_sparse(self):
4929        self._test_multi_remote_call(
4930            my_function,
4931            True,
4932            kwargs_fn=RpcTest._multi_kwargs_fn
4933        )
4934
4935    @dist_init
4936    def test_py_rref_args_sparse(self):
4937        self._py_rref_args(
4938            build_sparse_tensor(),
4939            build_sparse_tensor(),
4940            build_sparse_tensor(),
4941            build_sparse_tensor(),
4942            build_sparse_tensor() * 4
4943        )
4944
4945    @dist_init
4946    def test_py_rref_args_user_share_sparse(self):
4947        self._py_rref_args_user_share(
4948            build_sparse_tensor(),
4949            build_sparse_tensor(),
4950            build_sparse_tensor(),
4951            build_sparse_tensor(),
4952            build_sparse_tensor(),
4953            build_sparse_tensor(),
4954            build_sparse_tensor() * 6
4955        )
4956
4957    @dist_init
4958    def test_py_rpc_rref_args_sparse(self):
4959        self._py_rpc_rref_args(
4960            build_sparse_tensor(),
4961            build_sparse_tensor(),
4962            build_sparse_tensor(),
4963            build_sparse_tensor(),
4964            build_sparse_tensor(),
4965            build_sparse_tensor(),
4966            build_sparse_tensor() * 6
4967        )
4968
4969    @dist_init
4970    def test_nested_remote_sparse(self):
4971        self._nested_remote(
4972            nested_remote_sparse,
4973            build_sparse_tensor() + build_sparse_tensor()
4974        )
4975
4976    @dist_init
4977    def test_nested_rref_sparse(self):
4978        self._nested_rref(
4979            nested_rref_sparse,
4980            build_sparse_tensor() * 2,
4981            build_sparse_tensor() * 2
4982        )
4983
4984    @dist_init
4985    def test_nested_rref_stress_sparse(self):
4986        self._nested_rref_stress(
4987            nested_rref_sparse,
4988            build_sparse_tensor() * 2,
4989            build_sparse_tensor() * 2
4990        )
4991
4992    @dist_init
4993    def test_my_parameter_server_sparse(self):
4994        self._my_parameter_server(True)
4995
4996    # Test init_rpc without world_size argument
4997    @dist_init(setup_rpc=False)
4998    def test_dynamic_rpc_init_rpc(self):
4999        rpc.init_rpc(
5000            name=worker_name(self.rank),
5001            backend=self.rpc_backend,
5002            rank=self.rank,
5003            rpc_backend_options=self.rpc_backend_options,
5004        )
5005        rpc.shutdown()
5006
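    # Note: in dynamic mode only world_size may be omitted; rank is still
    # required (see test_dynamic_rpc_init_rpc_without_rank below).
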
5007    # Dynamic RPC new ranks communicate with existing ranks
5008    @dist_init(setup_rpc=False)
5009    def test_dynamic_rpc_new_rank_can_communicate_with_existing_rank(self):
5010        initialize_pg(self.file_init_method, self.rank, self.world_size)
5011
5012        if self.rank == 0:
5013            rpc.init_rpc(
5014                name=worker_name(self.rank),
5015                backend=self.rpc_backend,
5016                rank=self.rank,
5017                rpc_backend_options=self.rpc_backend_options,
5018            )
5019
5020        # Rank 0 will have initialized RPC by the time this barrier completes
5021        dist.barrier()
5022
5023        if self.rank != 0:
5024            # Newly joined ranks will be able to communicate with rank 0, since that was created first
5025            rpc.init_rpc(
5026                name=worker_name(self.rank),
5027                backend=self.rpc_backend,
5028                rank=self.rank,
5029                rpc_backend_options=self.rpc_backend_options,
5030            )
5031            result = rpc.rpc_sync(worker_name(0), torch.add, args=(torch.tensor(1), torch.tensor(1)))
5032            self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
5033
5034        # Barrier to ensure that all rpc_sync calls are finished
5035        dist.barrier()
5036        rpc.shutdown()
5037
5038    # Dynamic RPC existing ranks can communicate with new ranks
5039    @dist_init(setup_rpc=False)
5040    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank(self):
5041        initialize_pg(self.file_init_method, self.rank, self.world_size)
5042
5043        if self.rank == 0:
5044            rpc.init_rpc(
5045                name=worker_name(self.rank),
5046                backend=self.rpc_backend,
5047                rank=self.rank,
5048                rpc_backend_options=self.rpc_backend_options,
5049            )
5050
5051        # Rank 0 will have initialized RPC by the time this barrier completes
5052        dist.barrier()
5053
5054        # The remaining ranks join after the barrier
5055        if self.rank != 0:
5056            # Newly joined ranks will be able to communicate with rank 0, since that was created first
5057            rpc.init_rpc(
5058                name=worker_name(self.rank),
5059                backend=self.rpc_backend,
5060                rank=self.rank,
5061                rpc_backend_options=self.rpc_backend_options,
5062            )
5063
5064        dist.barrier()
5065        if self.rank == 0:
5066            for i in range(1, self.world_size):
5067                result = rpc.rpc_sync(worker_name(i), torch.add, args=(torch.tensor(1), torch.tensor(1)))
5068                self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
5069
5070        # Barrier to ensure that all rpc_sync calls are finished
5071        dist.barrier()
5072        rpc.shutdown()
5073
5074    # Dynamic RPC existing ranks can communicate with new ranks using CUDA rpc
5075    @skip_if_lt_x_gpu(2)
5076    @dist_init(setup_rpc=False)
5077    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank_cuda(self):
5078        initialize_pg(self.file_init_method, self.rank, self.world_size)
5079
5080        if self.rank == 0:
5081            options = self.rpc_backend_options
5082            for i in range(1, self.world_size):
5083                dst = worker_name(i)
5084                options.set_device_map(dst, {1: 0})
5085                options.set_device_map(dst, {0: 1})
5086            rpc.init_rpc(
5087                name=worker_name(self.rank),
5088                backend=self.rpc_backend,
5089                rank=self.rank,
5090                rpc_backend_options=options,
5091            )
5092
5093        # Rank 0 will have initialized RPC by the time this barrier completes
5094        dist.barrier()
5095
5096        # The remaining ranks join after the barrier
5097        if self.rank != 0:
5098            # Newly joined ranks will be able to communicate with rank 0, since that was created first
5099            rpc.init_rpc(
5100                name=worker_name(self.rank),
5101                backend=self.rpc_backend,
5102                rank=self.rank,
5103                rpc_backend_options=self.rpc_backend_options,
5104            )
5105
5106        # TODO: Cuda RPC is failing due to:
5107        # terminate called after throwing an instance of 'c10::Error'
5108        # what():  0 <= device && static_cast<size_t>(device) < device_allocator.size()
5109        # INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1937,
5110        # please report a bug to PyTorch. Allocator not initialized for device 1: did you call init?
5111        # dist.barrier()
5112        # if self.rank == 0:
5113        #     for i in range(1, self.world_size):
5114        #         x = torch.ones(2)
5115        #         result_on_device_0 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(0), 1))
5116        #         result_on_device_1 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(1), 1))
5117        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_0)
5118        #         self.assertEqual(torch.device('cuda:0'), result_on_device_0.device)
5119        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_1)
5120        #         self.assertEqual(torch.device('cuda:1'), result_on_device_1.device)
5121
5122        # Barrier to ensure that all rpc_sync calls are finished
5123        dist.barrier()
5124        rpc.shutdown()
5125
5126    @dist_init(setup_rpc=False)
5127    def test_dynamic_rpc_init_rpc_without_rank(self):
5128        # The default initialization uses file init.
5129        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
5130            rpc.init_rpc(
5131                name=worker_name(self.rank),
5132                backend=self.rpc_backend,
5133                rpc_backend_options=self.rpc_backend_options,
5134            )
5135
5136        # env init
5137        with self.assertRaisesRegex(ValueError, "environment variable RANK expected"):
5138            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://")
5139            rpc.init_rpc(
5140                name=worker_name(self.rank),
5141                backend=self.rpc_backend,
5142                rpc_backend_options=rpc_backend_options,
5143            )
5144
5145        # tcp init
5146        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
5147            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="tcp://127.0.0.1:23456")
5148            rpc.init_rpc(
5149                name=worker_name(self.rank),
5150                backend=self.rpc_backend,
5151                rpc_backend_options=rpc_backend_options,
5152            )
5153
5154    @dist_init(setup_rpc=False)
5155    def test_dynamic_and_static_init_rpc_together(self):
5156        # Initialize a static rpc group with size = self.world_size - 1
5157        dist.init_process_group(
5158            backend='gloo',
5159            init_method=self.file_init_method,
5160            rank=self.rank,
5161            world_size=self.world_size)
5162
5163        world_size_minus_one = self.world_size - 1
5164        if self.rank < world_size_minus_one:
5165            rpc.init_rpc(
5166                name=worker_name(self.rank),
5167                backend=self.rpc_backend,
5168                rank=self.rank,
5169                world_size=world_size_minus_one,
5170                rpc_backend_options=self.rpc_backend_options,
5171            )
5172
5173        dist.barrier()
5174
5175        # Attempt to add an additional dynamic group member
5176        if self.rank == world_size_minus_one:
5177            # Expect an error to be raised
5178            with self.assertRaisesRegex(RuntimeError, "RPC group mixes statically and dynamically "
5179                                        "initialized members which is not supported."):
5180                rpc.init_rpc(
5181                    name=worker_name(self.rank),
5182                    backend=self.rpc_backend,
5183                    rank=self.rank,
5184                    rpc_backend_options=self.rpc_backend_options,
5185                )
5186
5187class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon):
5188
5189    def _test_device_maps(self, options, errMsg):
5190        with self.assertRaisesRegex(ValueError, errMsg):
5191            rpc.init_rpc(
5192                name=worker_name(self.rank),
5193                backend=self.rpc_backend,
5194                rank=self.rank,
5195                world_size=self.world_size,
5196                rpc_backend_options=options,
5197            )
5198
5199        self.assertFalse(rpc.api._is_current_rpc_agent_set())
5200
5201    @skip_if_lt_x_gpu(2)
5202    def test_device_maps_wrong_worker_name(self):
5203        options = self.rpc_backend_options
5204        options.set_device_map("none_exist", {0: 1})
5205
5206        self._test_device_maps(
5207            options,
5208            errMsg="Node worker0 has invalid target node names in its device maps"
5209        )
5210
5211    @skip_if_lt_x_gpu(1)
5212    def test_device_maps_invalid_max_local_device(self):
5213        options = self.rpc_backend_options
5214        dst = worker_name((self.rank + 1) % self.world_size)
5215        options.set_device_map(dst, {torch.cuda.device_count(): 0})
5216
5217        self._test_device_maps(
5218            options,
5219            errMsg="Node worker0 has source devices with invalid indices in its device map for worker1"
5220        )
5221
5222    @skip_if_lt_x_gpu(1)
5223    def test_device_maps_invalid_max_remote_device(self):
5224        options = self.rpc_backend_options
5225        dst = worker_name((self.rank + 1) % self.world_size)
5226        options.set_device_map(dst, {0: torch.cuda.device_count()})
5227
5228        self._test_device_maps(
5229            options,
5230            errMsg="Node worker0 has target devices with invalid indices in its device map for worker1"
5231        )
5232
5233    @skip_if_lt_x_gpu(2)
5234    def test_device_maps_many_to_one(self):
5235        options = self.rpc_backend_options
5236        dst = worker_name((self.rank + 1) % self.world_size)
5237        options.set_device_map(dst, {1: 0})
5238        options.set_device_map(dst, {0: 0})
5239
5240        self._test_device_maps(
5241            options,
5242            errMsg="Node worker0 has duplicated target devices in its device map for worker1"
5243        )
5244
5245    @skip_if_lt_x_gpu(2)
5246    def test_device_maps_one_to_many(self):
5247        if self.rank == 0:
5248            options = self.rpc_backend_options
5249            dst = worker_name((self.rank + 1) % self.world_size)
5250            options.set_device_map(dst, {0: 1})
5251            with self.assertRaisesRegex(
5252                ValueError, "`set_device_map` only supports 1-to-1 mapping"
5253            ):
5254                options.set_device_map(dst, {0: 0})
5255
5256    @skip_if_lt_x_gpu(1)
5257    def test_device_maps_invalid_min_device(self):
5258        options = self.rpc_backend_options
5259        dst = worker_name((self.rank + 1) % self.world_size)
5260        with self.assertRaisesRegex(
5261            RuntimeError, "Device index must not be negative"
5262        ):
5263            options.set_device_map(dst, {-1: 0})
5264
5265        with self.assertRaisesRegex(
5266            RuntimeError, "Device index must not be negative"
5267        ):
5268            options.set_device_map(dst, {0: -1})
5269
5270    @staticmethod
5271    def _gpu_add(x, y):
5272        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 1]):
5273            return (x + y).to(0)
5274        else:
5275            raise ValueError("Wrong device affinity")
5276
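    # Device maps are directional: with {0: 1, 1: 0} below, a tensor on the
    # caller's cuda:0 arrives on the callee's cuda:1 (and vice versa), which
    # is why _gpu_add expects both inputs on device 1, and why the result the
    # callee leaves on device 0 comes back to the caller on device 1.
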
5277    @skip_if_lt_x_gpu(2)
5278    def test_device_maps_gpu(self):
5279        options = self.rpc_backend_options
5280        dst = worker_name((self.rank + 1) % self.world_size)
5281        options.set_device_map(dst, {0: 1, 1: 0})
5282
5283        rpc.init_rpc(
5284            name=worker_name(self.rank),
5285            backend=self.rpc_backend,
5286            rank=self.rank,
5287            world_size=self.world_size,
5288            rpc_backend_options=options,
5289        )
5290
5291        ret = rpc.rpc_sync(
5292            dst,
5293            TensorPipeAgentCudaRpcTest._gpu_add,
5294            args=(torch.zeros(2).to(0), torch.ones(2).to(0))
5295        )
5296        self.assertEqual(ret.device, torch.device(1))
5297        self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1))
5298        rpc.shutdown()
5299
5300    @staticmethod
5301    def _gpu_add_given_devices(x, y, x_to, y_to, z_to):
5302        x_device = "cpu" if x.device.type == "cpu" else x.device.index
5303        y_device = "cpu" if y.device.type == "cpu" else y.device.index
5304        if x_device == x_to and y_device == y_to:
5305            return x.to(z_to) + y.to(z_to)
5306        else:
5307            raise ValueError("Wrong device affinity")
5308
5309    def _test_device_maps_gpu(self, x_from, y_from, z_to, device_map, dst=None, fn=None):
5310        fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn
5311        x_to = device_map[x_from]
5312        y_to = device_map[y_from]
5313
5314        options = self.rpc_backend_options
5315        dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst
5316        options.set_device_map(dst, device_map)
5317
5318        rpc.init_rpc(
5319            name=worker_name(self.rank),
5320            backend=self.rpc_backend,
5321            rank=self.rank,
5322            world_size=self.world_size,
5323            rpc_backend_options=options,
5324        )
5325
5326        x = torch.zeros(2).to(x_from)
5327        y = torch.ones(2).to(y_from)
5328
5329        ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to))
5330
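        # Return values travel through the inverse mapping: a tensor the
        # callee leaves on z_to arrives on the caller device that maps to it.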
5331        reverse_device_map = {v: k for k, v in device_map.items()}
5332        z_from = reverse_device_map[z_to]
5333
5334        ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index
5335        self.assertEqual(ret_device, z_from)
5336        self.assertEqual(ret, torch.ones(2).to(z_from))
5337
5338        rpc.shutdown()
5339
5340    def test_device_map_cpu(self):
5341        self._test_device_maps_gpu(
5342            x_from="cpu",
5343            y_from="cpu",
5344            z_to="cpu",
5345            device_map={"cpu" : "cpu"},
5346            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
5347        )
5348
5349    @skip_if_lt_x_gpu(1)
5350    def test_device_map_cpu_to_gpu_default(self):
5351        self._test_device_maps_gpu(
5352            x_from="cpu",
5353            y_from="cpu",
5354            z_to=0,
5355            device_map={"cpu" : 0},
5356            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
5357        )
5358
5359    @skip_if_lt_x_gpu(2)
5360    def test_device_map_cpu_to_gpu_non_default(self):
5361        self._test_device_maps_gpu(
5362            x_from="cpu",
5363            y_from="cpu",
5364            z_to=1,
5365            device_map={"cpu" : 1},
5366            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
5367        )
5368
5369    @skip_if_lt_x_gpu(1)
5370    def test_device_map_gpu_to_cpu_default(self):
5371        self._test_device_maps_gpu(
5372            x_from=0,
5373            y_from=0,
5374            z_to="cpu",
5375            device_map={0 : "cpu"},
5376            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
5377        )
5378
5379    @skip_if_lt_x_gpu(2)
5380    def test_device_map_gpu_to_cpu_non_default(self):
5381        self._test_device_maps_gpu(
5382            x_from=1,
5383            y_from=1,
5384            z_to="cpu",
5385            device_map={1 : "cpu"},
5386            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
5387        )
5388
5389    @skip_if_lt_x_gpu(2)
5390    def test_device_map_gpu_default(self):
5391        self._test_device_maps_gpu(
5392            x_from=0,
5393            y_from=0,
5394            z_to=0,
5395            device_map={0 : 0}
5396        )
5397
5398    @skip_if_lt_x_gpu(2)
5399    def test_device_map_gpu_non_default(self):
5400        self._test_device_maps_gpu(
5401            x_from=1,
5402            y_from=1,
5403            z_to=1,
5404            device_map={1 : 1}
5405        )
5406
5407    @skip_if_lt_x_gpu(2)
5408    def test_device_map_gpu_default_to_non_default(self):
5409        self._test_device_maps_gpu(
5410            x_from=0,
5411            y_from=0,
5412            z_to=1,
5413            device_map={0 : 1}
5414        )
5415
5416    @skip_if_lt_x_gpu(2)
5417    def test_device_map_gpu_non_default_to_default(self):
5418        self._test_device_maps_gpu(
5419            x_from=1,
5420            y_from=1,
5421            z_to=0,
5422            device_map={1 : 0}
5423        )
5424
5425    @skip_if_lt_x_gpu(2)
5426    def test_device_map_gpu_mixed_1(self):
5427        self._test_device_maps_gpu(
5428            x_from=0,
5429            y_from=1,
5430            z_to=0,
5431            device_map={0 : 0, 1 : 1}
5432        )
5433
5434    @skip_if_lt_x_gpu(2)
5435    def test_device_map_gpu_mixed_2(self):
5436        self._test_device_maps_gpu(
5437            x_from=0,
5438            y_from=1,
5439            z_to=1,
5440            device_map={0 : 0, 1 : 1}
5441        )
5442
5443    @skip_if_lt_x_gpu(2)
5444    def test_device_map_gpu_mixed_3(self):
5445        self._test_device_maps_gpu(
5446            x_from=1,
5447            y_from=0,
5448            z_to=0,
5449            device_map={0 : 0, 1 : 1}
5450        )
5451
5452    @skip_if_lt_x_gpu(2)
5453    def test_device_map_gpu_mixed_4(self):
5454        self._test_device_maps_gpu(
5455            x_from=1,
5456            y_from=0,
5457            z_to=1,
5458            device_map={0 : 0, 1 : 1}
5459        )
5460
5461    @skip_if_lt_x_gpu(2)
5462    def test_device_map_gpu_mixed_5(self):
5463        self._test_device_maps_gpu(
5464            x_from=0,
5465            y_from=1,
5466            z_to=0,
5467            device_map={0 : 1, 1 : 0}
5468        )
5469
5470    @skip_if_lt_x_gpu(2)
5471    def test_device_map_gpu_mixed_6(self):
5472        self._test_device_maps_gpu(
5473            x_from=0,
5474            y_from=1,
5475            z_to=1,
5476            device_map={0 : 1, 1 : 0}
5477        )
5478
5479    @skip_if_lt_x_gpu(2)
5480    def test_device_map_gpu_mixed_7(self):
5481        self._test_device_maps_gpu(
5482            x_from=1,
5483            y_from=0,
5484            z_to=0,
5485            device_map={0 : 1, 1 : 0}
5486        )
5487
5488    @skip_if_lt_x_gpu(2)
5489    def test_device_map_gpu_mixed_8(self):
5490        self._test_device_maps_gpu(
5491            x_from=1,
5492            y_from=0,
5493            z_to=1,
5494            device_map={0 : 1, 1 : 0}
5495        )
5496
5497    @skip_if_lt_x_gpu(2)
5498    def test_device_map_gpu_mixed_self_1(self):
5499        self._test_device_maps_gpu(
5500            x_from=0,
5501            y_from=1,
5502            z_to=0,
5503            device_map={0 : 0, 1 : 1},
5504            dst=worker_name(self.rank)
5505        )
5506
5507    @skip_if_lt_x_gpu(2)
5508    def test_device_map_gpu_mixed_self_2(self):
5509        self._test_device_maps_gpu(
5510            x_from=0,
5511            y_from=1,
5512            z_to=1,
5513            device_map={0 : 0, 1 : 1},
5514            dst=worker_name(self.rank)
5515        )
5516
5517    @skip_if_lt_x_gpu(2)
5518    def test_device_map_gpu_mixed_self_3(self):
5519        self._test_device_maps_gpu(
5520            x_from=1,
5521            y_from=0,
5522            z_to=0,
5523            device_map={0 : 0, 1 : 1},
5524            dst=worker_name(self.rank)
5525        )
5526
5527    @skip_if_lt_x_gpu(2)
5528    def test_device_map_gpu_mixed_self_4(self):
5529        self._test_device_maps_gpu(
5530            x_from=1,
5531            y_from=0,
5532            z_to=1,
5533            device_map={0 : 0, 1 : 1},
5534            dst=worker_name(self.rank)
5535        )
5536
5537    @skip_if_lt_x_gpu(2)
5538    def test_device_map_gpu_mixed_self_5(self):
5539        self._test_device_maps_gpu(
5540            x_from=0,
5541            y_from=1,
5542            z_to=0,
5543            device_map={0 : 1, 1 : 0},
5544            dst=worker_name(self.rank)
5545        )
5546
5547    @skip_if_lt_x_gpu(2)
5548    def test_device_map_gpu_mixed_self_6(self):
5549        self._test_device_maps_gpu(
5550            x_from=0,
5551            y_from=1,
5552            z_to=1,
5553            device_map={0 : 1, 1 : 0},
5554            dst=worker_name(self.rank)
5555        )
5556
5557    @skip_if_lt_x_gpu(2)
5558    def test_device_map_gpu_mixed_self_7(self):
5559        self._test_device_maps_gpu(
5560            x_from=1,
5561            y_from=0,
5562            z_to=0,
5563            device_map={0 : 1, 1 : 0},
5564            dst=worker_name(self.rank)
5565        )
5566
5567    @skip_if_lt_x_gpu(2)
5568    def test_device_map_gpu_mixed_self_8(self):
5569        self._test_device_maps_gpu(
5570            x_from=1,
5571            y_from=0,
5572            z_to=1,
5573            device_map={0 : 1, 1 : 0},
5574            dst=worker_name(self.rank)
5575        )
5576
5577    @staticmethod
5578    def _gpu_add_multi_gpu(x, y):
5579        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]):
5580            return x.to(0) + y, x - y.to(1)
5581        else:
5582            raise ValueError("Wrong device affinity")
5583
5584    def _test_device_maps_multi_gpu(self, dst):
5585        options = self.rpc_backend_options
5586        options.set_device_map(dst, {0: 1})
5587        options.set_device_map(dst, {1: 0})
5588
5589        rpc.init_rpc(
5590            name=worker_name(self.rank),
5591            backend=self.rpc_backend,
5592            rank=self.rank,
5593            world_size=self.world_size,
5594            rpc_backend_options=options,
5595        )
5596
5597        x = torch.zeros(2).to(0)
5598        y = torch.ones(2).to(1)
5599        rets = rpc.rpc_sync(
5600            dst,
5601            TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu,
5602            args=(x, y)
5603        )
5604
5605        self.assertEqual(rets[0].device, torch.device(1))
5606        self.assertEqual(rets[1].device, torch.device(0))
5607        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
5608        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
5609        rpc.shutdown()
5610
5611    @skip_if_lt_x_gpu(2)
5612    def test_device_maps_multi_gpu(self):
5613        dst = worker_name((self.rank + 1) % self.world_size)
5614        self._test_device_maps_multi_gpu(dst)
5615
5616    @skip_if_lt_x_gpu(2)
5617    def test_device_maps_multi_gpu_self(self):
5618        dst = worker_name(self.rank)
5619        self._test_device_maps_multi_gpu(dst)
5620
5621    @staticmethod
5622    def _gpu_add_return_to_gpu(x, y):
5623        if x.device.type == 'cpu' and y.device.type == 'cpu':
5624            return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3)
5625        else:
5626            raise ValueError("Wrong device affinity")
5627
5628    @skip_if_lt_x_gpu(2)
5629    def test_device_maps_in_options(self):
5630        dst = worker_name((self.rank + 1) % self.world_size)
5631        options = self.rpc_backend_options
5632
5633        rpc.init_rpc(
5634            name=worker_name(self.rank),
5635            backend=self.rpc_backend,
5636            rank=self.rank,
5637            world_size=self.world_size,
5638            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
5639                init_method=options.init_method,
5640                num_worker_threads=options.num_worker_threads,
5641                device_maps={dst: {0: 1, 1: 0}},
5642                _transports=tp_transports()
5643            )
5644        )
5645
5646        rets = rpc.rpc_sync(
5647            dst,
5648            TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu,
5649            args=(torch.zeros(2).to(0), torch.ones(2).to(1))
5650        )
5651        self.assertEqual(rets[0].device, torch.device(1))
5652        self.assertEqual(rets[1].device, torch.device(0))
5653        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
5654        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
5655        rpc.shutdown()
5656
5657    def _test_device_maps_return_to_gpu(self, dst):
5658        options = self.rpc_backend_options
5659
5660        options.set_device_map(dst, {0: 1})
5661        options.set_device_map(dst, {1: 2})
5662        options.set_device_map(dst, {2: 3})
5663        options.set_device_map(dst, {3: 0})
5664
5665        rpc.init_rpc(
5666            name=worker_name(self.rank),
5667            backend=self.rpc_backend,
5668            rank=self.rank,
5669            world_size=self.world_size,
5670            rpc_backend_options=options,
5671        )
5672
5673        rets = rpc.rpc_sync(
5674            dst,
5675            TensorPipeAgentCudaRpcTest._gpu_add_return_to_gpu,
5676            args=(torch.zeros(2), torch.ones(2))
5677        )
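        # The callee creates rets[i] on cuda:i; the inverse of the map
        # {0: 1, 1: 2, 2: 3, 3: 0} returns it on caller device (i + 3) % 4.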
5678        for i in range(len(rets)):
5679            self.assertEqual(rets[i].device, torch.device((3 + i) % 4))
5680        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3))
5681        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
5682        self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1))
5683        self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2))
5684        rpc.shutdown()
5685
5686    @skip_if_lt_x_gpu(4)
5687    def test_device_maps_return_to_gpu(self):
5688        dst = worker_name((self.rank + 1) % self.world_size)
5689        self._test_device_maps_return_to_gpu(dst)
5690
5691    @skip_if_lt_x_gpu(4)
5692    def test_device_maps_return_to_gpu_self(self):
5693        dst = worker_name(self.rank)
5694        self._test_device_maps_return_to_gpu(dst)
5695
5696    @staticmethod
5697    def _add_to_gpu(x, y):
5698        return (x + y).to(0)
5699
5700    def _test_device_maps_missing_config(self, mode):
5701        dst = worker_name((self.rank + 1) % self.world_size)
5702        errMsg = (
5703            "TensorPipe RPC backend only supports CPU tensors by default.*"
5704            "`set_device_map` on `TensorPipeRpcBackendOptions`"
5705        )
5706
5707        with self.assertRaisesRegex(RuntimeError, errMsg):
5708            if mode == RPCExecMode.SYNC:
5709                rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1))
5710            elif mode == RPCExecMode.REMOTE:
5711                rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here()
5712            else:
5713                raise ValueError(f"unexpected mode {mode}")
5714
5715        # make sure RPC is still functioning
5716        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
5717        self.assertEqual(ret, torch.ones(2) + 1)
5718
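    # The remedy the error message points to, sketched (values illustrative):
    #
    #   opts = rpc.TensorPipeRpcBackendOptions(init_method=...)
    #   opts.set_device_map(dst, {0: 0})  # caller cuda:0 -> callee cuda:0
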
5719    def _test_device_maps_missing_config_response(self, mode):
5720        dst = worker_name((self.rank + 1) % self.world_size)
5721        errMsg = "Response device mapping is not available"
5722
5723        with self.assertRaisesRegex(RuntimeError, errMsg):
5724            if mode == RPCExecMode.SYNC:
5725                rpc.rpc_sync(
5726                    dst,
5727                    TensorPipeAgentCudaRpcTest._add_to_gpu,
5728                    args=(torch.zeros(2), 1)
5729                )
5730            elif mode == RPCExecMode.REMOTE:
5731                rpc.remote(
5732                    dst,
5733                    TensorPipeAgentCudaRpcTest._add_to_gpu,
5734                    args=(torch.zeros(2), 1)
5735                ).to_here()
5736            else:
5737                raise ValueError(f"unexpected mode {mode}")
5738
5739        # make sure RPC is still functioning
5740        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
5741        self.assertEqual(ret, torch.ones(2) + 1)
5742
5743    @skip_if_lt_x_gpu(1)
5744    @dist_init
5745    def test_device_maps_missing_config(self):
5746        self._test_device_maps_missing_config(RPCExecMode.SYNC)
5747
5748    @skip_if_lt_x_gpu(1)
5749    def test_device_maps_missing_config_not_timeout(self):
5752
5753        rpc.init_rpc(
5754            name=worker_name(self.rank),
5755            backend=self.rpc_backend,
5756            rank=self.rank,
5757            world_size=self.world_size,
5758            rpc_backend_options=self.rpc_backend_options
5759        )
5760
5761        timeout = rpc.get_rpc_timeout()
5762
5763        tik = time.time()
5764        self._test_device_maps_missing_config(RPCExecMode.SYNC)
5765        rpc.shutdown()
5766        tok = time.time()
5767
5768        self.assertLess(tok - tik, timeout)
5769
5770    @skip_if_lt_x_gpu(1)
5771    @dist_init
5772    def test_device_maps_missing_config_loop(self):
5773        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
5774            self._test_device_maps_missing_config(RPCExecMode.SYNC)
5775
5776    @skip_if_lt_x_gpu(1)
5777    @dist_init
5778    def test_device_maps_missing_config_response(self):
5779        self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
5780
5781    @skip_if_lt_x_gpu(1)
5782    @dist_init
5783    def test_device_maps_missing_config_response_loop(self):
5784        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
5785            self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
5786
5787    @skip_if_lt_x_gpu(1)
5788    @dist_init
5789    def test_device_maps_missing_config_remote(self):
5790        self._test_device_maps_missing_config(RPCExecMode.REMOTE)
5791
5792    @skip_if_lt_x_gpu(1)
5793    @dist_init
5794    def test_device_maps_missing_config_remote_response(self):
5795        self._test_device_maps_missing_config_response(RPCExecMode.REMOTE)
5796
5797    @skip_if_lt_x_gpu(2)
5798    def test_device_maps_remote(self):
5799        options = self.rpc_backend_options
5800        dst = worker_name((self.rank + 1) % self.world_size)
5801        options.set_device_map(dst, {1: 0})
5802
5803        rpc.init_rpc(
5804            name=worker_name(self.rank),
5805            backend=self.rpc_backend,
5806            rank=self.rank,
5807            world_size=self.world_size,
5808            rpc_backend_options=options,
5809        )
5810
5811        rref = rpc.remote(
5812            dst,
5813            TensorPipeAgentCudaRpcTest._add_to_gpu,
5814            args=(torch.zeros(2), 1)
5815        )
5816
5817        self.assertEqual(rref.to_here().device.index, 1)
5818        self.assertEqual(rref.to_here(), torch.ones(2).to(1))
5819
5820        rpc.shutdown()
5821
5822    @staticmethod
5823    def _slow_add_on_user_stream(x, y):
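        # Compute on a dedicated side stream: wait for work already queued on
        # the current stream, mark x and y as in use on the side stream (so
        # the caching allocator won't reuse their memory prematurely), run a
        # long spin kernel plus the add there, then hand the result back to
        # the current stream the same way.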
5824        s0 = torch.cuda.current_stream(x.device)
5825        s1 = torch.cuda.Stream(device=x.device)
5826        s1.wait_stream(s0)
5827        x.record_stream(s1)
5828        y.record_stream(s1)
5829        with torch.cuda.stream(s1):
5830            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
5831            z = x + y
5832        s0.wait_stream(s1)
5833        z.record_stream(s0)
5834        return z
5835
    def _test_custom_stream(self, fn, device_map):
        options = self.rpc_backend_options
        dst = worker_name((self.rank + 1) % self.world_size)
        options.set_device_map(dst, device_map)

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        fn(dst)

        rpc.shutdown()

    def _test_stream_sync(self, dst):
        x = torch.ones(2, 2).to(0)
        ret = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
            args=(x, x)
        )
        self.assertEqual(ret, 2 * x)

    @skip_if_lt_x_gpu(2)
    def test_custom_stream(self):
        self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"})

    def _test_stream_multi_async(self, dst):
        futs = []
        for i in range(20):
            x = torch.ones(2, 2).to(0) * i
            futs.append(
                rpc.rpc_async(
                    dst,
                    TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
                    args=(x, x)
                )
            )

        for i in range(20):
            self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i)

    @skip_if_lt_x_gpu(2)
    def test_custom_stream_multi(self):
        self._test_custom_stream(
            self._test_stream_multi_async,
            {"cuda:0": "cuda:1"}
        )

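    # Nested hop: dst forwards (x, y) to nested_dst for the slow add, then
    # performs a second slow add with z locally, so stream synchronization is
    # exercised across two device-mapped RPC layers.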
    @staticmethod
    def _nested_slow_add_on_user_stream(dst, x, y, z):
        ret = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
            args=(x, y)
        )

        return TensorPipeAgentCudaRpcTest._slow_add_on_user_stream(ret, z)

    def _test_stream_nested_sync(self, dst):
        x = torch.ones(2, 2).to(0)
        y = torch.ones(2, 2).to(0) * 2
        z = torch.ones(2, 2).to(0) * 3
        nested_dst = worker_name((self.rank + 2) % self.world_size)
        ret = rpc.rpc_sync(
            dst,
            TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
            args=(nested_dst, x, y, z)
        )
        self.assertEqual(ret, 6 * x)

    @skip_if_lt_x_gpu(2)
    def test_custom_stream_nested(self):
        self._test_custom_stream(
            self._test_stream_nested_sync,
            {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
        )

    def _test_stream_nested_multi_async(self, dst):
        if self.rank == 0:
            futs = []
            n = 5
            xs, ys, zs = [], [], []
            for i in range(n):
                x = torch.ones(2, 2).to(0) * (i - 1)
                y = torch.ones(2, 2).to(0) * i
                z = torch.ones(2, 2).to(0) * (i + 1)
                xs.append(x)
                ys.append(y)
                zs.append(z)
                nested_dst = worker_name((self.rank + 2) % self.world_size)
                futs.append(
                    rpc.rpc_async(
                        dst,
                        TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
                        args=(nested_dst, x, y, z)
                    )
                )

            for i in range(n):
                self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i])

    @skip_if_lt_x_gpu(2)
    def test_custom_stream_nested_multi(self):
        self._test_custom_stream(
            self._test_stream_nested_multi_async,
            {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
        )

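    # Deliberately mixes a CPU tensor with a CUDA tensor so the remote add
    # fails with the device-mismatch RuntimeError asserted below.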
    @staticmethod
    def _gpu_add_wrong_gpus(x, y):
        if x.is_cuda and y.is_cuda:
            return x.cpu() + y.cuda()
        else:
            raise ValueError("Wrong device affinity")

    @skip_if_lt_x_gpu(1)
    def test_device_mismatch(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {0: 0})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        x = torch.zeros(2).to(0)
        y = torch.ones(2).to(0)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected all tensors to be on the same device, but found at least two devices"
        ):
            rets = rpc.rpc_sync(
                dst,
                TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus,
                args=(x, y)
            )

        rpc.shutdown()

    def _test_rref_synchronization(self, local_device, remote_device):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {local_device: remote_device})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        if self.rank == 1:
            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here().
            # If to_here() is properly synchronized with forward(x), the results must be identical.
            # This test needs multiple iterations and a significant batch size to simulate real
            # training of a CNN on MNIST-like data.
            # See https://github.com/pytorch/pytorch/issues/54771
            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
            for _ in range(10):
                x = torch.randn(200, 1, 28, 28).to(local_device)
                actual = rref.remote().forward(x).to_here()
                expected = rref.rpc_sync().forward(x)
                self.assertEqual(actual, expected)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_rref_to_here_synchronization1(self):
        self._test_rref_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_to_here_synchronization2(self):
        self._test_rref_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_to_here_synchronization3(self):
        self._test_rref_synchronization("cuda:1", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_rref_to_here_synchronization4(self):
        self._test_rref_synchronization("cuda:0", "cuda:1")

    def _test_rref_as_arg_synchronization(
        self,
        local_device,
        remote_device,
        devices_options=None
    ):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {local_device: remote_device})

        input_src = worker_name((self.rank - 1 + self.world_size) % self.world_size)
        options.set_device_map(input_src, {remote_device: local_device})

        if devices_options is not None:
            options.set_devices(devices_options[self.rank])

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        if self.rank == 1:
            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here().
            # If to_here() is properly synchronized with forward(x), the results must be identical.
            # This test needs multiple iterations and a significant batch size to simulate real
            # training of a CNN on MNIST-like data.
            # See https://github.com/pytorch/pytorch/issues/54771
            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
            for _ in range(10):
                rref_x = RRef(torch.randn(200, 1, 28, 28).to(local_device))
                actual = rref.remote().forward(rref_x, True).to_here()
                expected = rref.rpc_sync().forward(rref_x, True)
                self.assertEqual(actual, expected)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_rref_as_arg_synchronization1(self):
        self._test_rref_as_arg_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_as_arg_synchronization2(self):
        self._test_rref_as_arg_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_as_arg_synchronization3(self):
        self._test_rref_as_arg_synchronization("cuda:1", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_rref_as_arg_synchronization4(self):
        self._test_rref_as_arg_synchronization("cuda:0", "cuda:1")

    @skip_if_lt_x_gpu(1)
    def test_rref_as_arg_synchronization5(self):
        self._test_rref_as_arg_synchronization(
            "cuda:0",
            "cuda:0",
            [["cuda:0"] for _ in range(4)],  # devices_options
        )

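    # Relays an RRef through a third worker: to_here() on the relay forces an
    # extra device-mapped copy of the forward output before it returns to the
    # caller, exercising CUDA stream synchronization along the way.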
    @staticmethod
    def _rref_relay(rref):
        return rref.to_here()

    def _test_rref_forward_synchronization(self, local_device, remote_device):
        options = self.rpc_backend_options

        input_src = worker_name(0)
        model_dst = worker_name(1)
        out_relay = worker_name(2)

        if self.rank == 0:
            # Device map used for 1) model construction and 2) forward execution.
            options.set_device_map(model_dst, {local_device: remote_device})

            # The forward output will first be copied to the relay node before
            # returning to the worker. This is intentional, to test RRef
            # forward CUDA stream synchronizations.
            options.set_device_map(out_relay, {local_device: local_device})
        elif self.rank == 1:
            # worker1 hosts the model and runs forward. The forward function
            # calls RRef.to_here(), hence worker1 needs to configure the device map.
            options.set_device_map(input_src, {remote_device: local_device})
        elif self.rank == 2:
            # worker2 receives the output RRef and calls to_here() on it, and
            # hence needs to configure the device map as well.
            options.set_device_map(model_dst, {local_device: remote_device})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        if self.rank == 0:
            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here().
            # If to_here() is properly synchronized with forward(x), the results must be identical.
            # This test needs multiple iterations and a significant batch size to simulate real
            # training of a CNN on MNIST-like data.
            # See https://github.com/pytorch/pytorch/issues/54771
            rref = rpc.remote(model_dst, MyConvNetForMNIST, args=(remote_device,))
            for _ in range(10):
                rref_input = RRef(torch.randn(200, 1, 28, 28).to(local_device))
                rref_out = rref.remote().forward(rref_input, True)
                out = rpc.remote(
                    out_relay,
                    TensorPipeAgentCudaRpcTest._rref_relay,
                    args=(rref_out,)
                ).to_here()
                expected = rref.rpc_sync().forward(rref_input, True)
                self.assertEqual(out, expected)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_rref_forward_synchronization1(self):
        self._test_rref_forward_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_forward_synchronization2(self):
        self._test_rref_forward_synchronization("cuda:0", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_rref_forward_synchronization3(self):
        self._test_rref_forward_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_rref_forward_synchronization4(self):
        self._test_rref_forward_synchronization("cuda:1", "cuda:1")

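    # Owner-RRef variant: a single worker owns the RRef, so fetching the value
    # goes through localValue rather than a network hop; the test checks that
    # this path still synchronizes with the stream that produced the value.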
    def _test_owner_rref_forward_synchronization(self, local_device, remote_device):
        if self.rank == 0:
            options = self.rpc_backend_options
            options.set_device_map("w0", {local_device: remote_device})
            rpc.init_rpc(
                "w0",
                rank=0,
                world_size=1,
                rpc_backend_options=options
            )

            model = rpc.remote(
                "w0", torch.nn.Linear, (2048, 20000)
            ).remote().to(remote_device)
            for _ in range(30):
                data = torch.rand(2048, 2048).to(local_device)
                output = model.rpc_sync().forward(data)
                # to_here() internally calls localValue as the caller is
                # the owner of the RRef.
                v0 = rpc.RRef(output).remote().sum().to_here().item()
                v1 = output.sum().item()
                self.assertEqual(v0, v1)

            rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_owner_rref_forward_synchronization1(self):
        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization2(self):
        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization3(self):
        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization4(self):
        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1")

    @staticmethod
    def _return_tensor_view(i):
        x = torch.ones(1000, 200).cuda(0) * i
        torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
        # Serialization of the return value will create a new tensor from the
        # view, which is done outside of the user function.
        return x.split(100)[0]

    @skip_if_lt_x_gpu(1)
    def test_tensor_view_as_return_value(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {0: 0})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        futs = []
        for i in range(5):
            futs.append(rpc.rpc_async(
                dst,
                TensorPipeAgentCudaRpcTest._return_tensor_view,
                args=(i,)
            ))

        for i in range(5):
            self.assertEqual(torch.ones(100, 200) * i, futs[i].wait())

        rpc.shutdown()

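    # Illustrative sketch (not used by the tests): split() returns views that
    # share storage with the source tensor, which is why the RPC layer has to
    # materialize a fresh tensor when serializing the returned slice.
    @staticmethod
    def _view_shares_storage_sketch():
        x = torch.ones(10, 2)
        head = x.split(5)[0]  # a view over the first five rows
        assert head.data_ptr() == x.data_ptr()  # same underlying storage
        copied = head.clone()  # roughly what serialization does
        assert copied.data_ptr() != x.data_ptr()
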
    @skip_if_lt_x_gpu(2)
    def test_devices_option_mismatch(self):
        with self.assertRaisesRegex(
            ValueError,
            "Node worker0 has unexpected source devices in its device map for worker1"
        ):
            dst = worker_name((self.rank + 1) % self.world_size)
            options = self.rpc_backend_options
            options.set_device_map(dst, {0: 0})
            options.set_devices([1])

            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=options,
            )

            rpc.shutdown()

    @skip_if_lt_x_gpu(2)
    def test_devices_option_mismatch_reverse(self):
        with self.assertRaisesRegex(
            ValueError,
            "Node worker0 has unexpected target devices in its device map for worker1"
        ):
            dst = worker_name((self.rank + 1) % self.world_size)

            options = rpc.TensorPipeRpcBackendOptions(
                init_method=self.rpc_backend_options.init_method,
                num_worker_threads=self.rpc_backend_options.num_worker_threads,
                device_maps={dst: {0: 1}},
                devices=[0]
            )

            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=options,
            )

            rpc.shutdown()

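    # Future(devices=...) accepts device indices, device strings, and
    # torch.device objects; the three tests below only check that construction
    # succeeds for each spelling.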
    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_int(self):
        fut = Future(devices=[0])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_str(self):
        fut = Future(devices=["cuda:0"])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_device(self):
        fut = Future(devices=[torch.device("cuda", 0)])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_not_cuda(self):
        with self.assertRaisesRegex(
            ValueError, "Expected devices to have indices, got cpu"
        ):
            fut = Future(devices=["cpu"])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=False
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_list_with_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=False
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=False
        )

    @skip_if_lt_x_gpu(2)
    def test_cuda_future_callback_changes_devices(self):
        # We check proper CUDA stream synchronization by filling the tensor with
        # the expected value in one stream, and reading it from another stream.
        tensor0 = torch.zeros((100,), device="cuda:0")
        tensor1 = torch.zeros((100,), device="cuda:1")
        parent_future = Future(devices=["cuda:0", "cuda:1"])

        def cb(fut):
            t0 = fut.value()
            tensor1.copy_(t0, non_blocking=True)
            return tensor1

        child_future = parent_future.then(cb)
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor0.fill_(1)
                parent_future.set_result(tensor0)
        with torch.cuda.device("cuda:1"):
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(another_stream):
                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())

    @skip_if_lt_x_gpu(2)
    def test_cuda_future_value_on_bad_device(self):
        tensor0 = torch.zeros((100,), device="cuda:0")
        tensor1 = torch.zeros((100,), device="cuda:1")
        parent_future = Future(devices=["cuda:1"])

        # As a plus, we test that futures still invoke callbacks even in case of
        # error, and that the child futures are successful if those callbacks
        # don't access the parent future.
        def cb(fut):
            with torch.cuda.device("cuda:1"):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor1.fill_(1)
                return tensor1

        child_future = parent_future.then(cb)
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor0.fill_(1)
                parent_future.set_result(tensor0)
        with self.assertRaisesRegex(
            ValueError,
            r"The result contained tensors residing on device\(s\) cuda:0 "
            r"which are not among the expected device\(s\) cuda:1",
        ):
            parent_future.wait()
        with torch.cuda.device("cuda:1"):
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(another_stream):
                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())

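    # Minimal sketch of the @rpc.functions.async_execution pattern exercised by
    # the test below (the actual async_cuda_sleep_and_set_to_one helper is
    # defined elsewhere in this file and may differ): the UDF returns a Future
    # so the callee can release its RPC thread while CUDA work is still queued.
    @staticmethod
    @rpc.functions.async_execution
    def _async_set_to_one_sketch(t):
        fut = Future(devices=[t.device])
        torch.cuda._sleep(FIFTY_MIL_CYCLES)  # queue delayed work on the current stream
        fut.set_result(t.fill_(1))  # fill_ is ordered after the sleep on that stream
        return fut
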
    @skip_if_lt_x_gpu(1)
    def test_async_execution_with_cuda_future(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {"cuda:0": "cuda:0"})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        t = torch.zeros((100,), device="cuda:0")
        fut = rpc.rpc_async(dst, async_cuda_sleep_and_set_to_one, args=(t,))
        another_stream = torch.cuda.Stream("cuda:0")
        with torch.cuda.stream(another_stream):
            self.assertTrue(torch.eq(fut.wait(), 1).all().item())

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_async_execution_nested_with_cuda_future(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        nested_dst = worker_name((self.rank + 2) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {"cuda:0": "cuda:0"})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        a = torch.ones((100,), device="cuda:0")
        b = torch.ones((100,), device="cuda:0")
        c = torch.ones((100,), device="cuda:0")
        fut = rpc.rpc_async(dst, async_cuda_nested_add, args=(nested_dst, a, b, c))
        another_stream = torch.cuda.Stream("cuda:0")
        with torch.cuda.stream(another_stream):
            self.assertTrue(torch.eq(fut.wait(), 3).all().item())

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_modify_tensor_inplace(self):
        tensor = torch.zeros((100,), device="cuda:0")
        future = Future(devices=["cuda:0"])
        future.set_result(tensor)
        # It's weird to modify the value of a future once it's complete, but
        # technically possible. Currently this is considered undefined behavior
        # (in practice the future will ignore the modification and still
        # synchronize with the original value). We could one day add logic to
        # detect and warn or throw in such cases, but for now we just check that
        # this doesn't crash.
        tensor.fill_(1)
        future.wait()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_replace_tensor(self):
        tensor_list = [torch.zeros((100,), device="cuda:0")]
        future = Future(devices=["cuda:0"])
        future.set_result(tensor_list)
        # It's weird to modify the value of a future once it's complete, but
        # technically possible. Currently this is considered undefined behavior
        # (in practice the future will ignore the modification and still
        # synchronize with the original value). We could one day add logic to
        # detect and warn or throw in such cases, but for now we just check that
        # this doesn't crash.
        # We set things up so that the original tensor contained in the list
        # gets deleted once we replace it with the other one. This will
        # invalidate any cached information held by the future.
        tensor_list[0] = torch.ones((100,), device="cuda:0")
        future.wait()

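    # TensorWrapper (defined elsewhere in this file) carries attributes that
    # cannot be pickled; interacting with it only through RRef proxies keeps
    # the object on its owner, so those attributes are never serialized.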
    @skip_if_lt_x_gpu(1)
    def test_rref_with_unpickleable_attributes(self):
        dst = worker_name((self.rank + 1) % self.world_size)
        options = self.rpc_backend_options
        options.set_device_map(dst, {"cuda:0": "cuda:0"})

        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )

        rref = rpc.remote(dst, TensorWrapper, args=(torch.zeros(42, device="cuda:0"),))
        rref.rpc_sync().increase(1)
        ret = rref.rpc_sync().sum()
        self.assertEqual(ret, 42)

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_cuda_sparse_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=True
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=True
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_custom_class_with_cuda_sparse_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=True
        )