# mypy: ignore-errors

import re
import sys
import time
from functools import partial, wraps
from typing import Tuple

import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import _rref_context_get_debug_info
from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN


if not dist.is_available():
    print("c10d not available, skipping tests", file=sys.stderr)
    sys.exit(0)


INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}"


def dist_init(
    old_test_method=None,
    setup_rpc: bool = True,
    clean_shutdown: bool = True,
    faulty_messages=None,
    messages_to_delay=None,
):
    """
    We use this decorator to set up and tear down state, since
    MultiProcessTestCase runs each `test*` method in a separate process and
    each process just runs the `test*` method without actually calling the
    'setUp' and 'tearDown' methods of unittest.

    Note: `faulty_messages` takes the string representations of the
    MessageTypes that should be sent through the faulty agent's send function.
    By default, all retriable messages ("RREF_FORK_REQUEST",
    "RREF_CHILD_ACCEPT", "RREF_USER_DELETE", "CLEANUP_AUTOGRAD_CONTEXT_REQ")
    use the faulty send (this default is set from
    faulty_rpc_agent_test_fixture.py).
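
    Example (an illustrative sketch, not taken from this file; assumes a test
    class built on MultiProcessTestCase / an RPC test fixture that provides
    ``self.rank``, ``self.world_size``, ``self.rpc_backend``, and
    ``self.rpc_backend_options``)::

        class MyRpcTest(MyRpcAgentTestFixture):  # hypothetical fixture class
            @dist_init
            def test_basic(self):
                ...  # RPC is already initialized for this worker

            @dist_init(setup_rpc=False)
            def test_manual_rpc_setup(self):
                ...  # call rpc.init_rpc / rpc.shutdown yourself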
    """
    # If dist_init is used without arguments (e.g. @dist_init), old_test_method
    # is the decorated test method and we return the wrapper directly. If
    # dist_init is called with arguments (e.g. @dist_init(clean_shutdown=False)),
    # old_test_method is None, so we return a functools.partial that acts as the
    # real decorator; when that partial is applied to the test method, dist_init
    # is called again with old_test_method and the remaining arguments bound.
    if old_test_method is None:
        return partial(
            dist_init,
            setup_rpc=setup_rpc,
            clean_shutdown=clean_shutdown,
            faulty_messages=faulty_messages,
            messages_to_delay=messages_to_delay,
        )

    @wraps(old_test_method)
    def new_test_method(self, *args, **kwargs):
        # Set _ignore_rref_leak to False to make sure OwnerRRefs are properly
        # deleted in tests.
        import torch.distributed.rpc.api as api

        api._ignore_rref_leak = False
        self.worker_id = self.rank
        self.setup_fault_injection(faulty_messages, messages_to_delay)

        rpc_backend_options = self.rpc_backend_options
        if setup_rpc:
            if TEST_WITH_TSAN:
                # TSAN runs much slower.
                rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5
                rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60

            rpc.init_rpc(
                name="worker%d" % self.rank,
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=rpc_backend_options,
            )

        return_value = old_test_method(self, *args, **kwargs)

        if setup_rpc:
            rpc.shutdown(graceful=clean_shutdown)

        return return_value

    return new_test_method


def noop() -> None:
    pass


def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
    """
    Loops until an RPC to the given rank fails. This is used in unit tests to
    detect that the node has failed. Returns the error message of the failing
    RPC.

    Args:
        rank (int): Rank of the node expected to fail.
        expected_error_regex (str, optional): Regex that the exception message
            is expected to match. Useful to ensure a specific failure occurs,
            not just any failure.
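
    Example (sketch; ``dst_rank`` is a placeholder for the rank being killed,
    and workers are assumed to be named ``"worker<rank>"`` as done by
    ``dist_init``)::

        error_msg = wait_until_node_failure(dst_rank, expected_error_regex="RPC")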
    """
    while True:
        try:
            rpc.rpc_sync(f"worker{rank}", noop, args=())
            time.sleep(0.1)
        except Exception as e:
            if re.search(pattern=expected_error_regex, string=str(e)):
                return str(e)


def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
    """
    The RRef protocol holds forkIds of RRefs in a map until those forks are
    confirmed by the owner. The message confirming the fork may arrive after
    our tests check whether this map is empty, which leads to failures and
    flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    and processed. Call this function before asserting that the map returned
    by _rref_context_get_debug_info is empty.
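
    Example (sketch)::

        wait_until_pending_futures_and_users_flushed()
        debug_info = _rref_context_get_debug_info()
        assert int(debug_info["num_pending_users"]) == 0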
    """
    start = time.time()
    while True:
        debug_info = _rref_context_get_debug_info()
        num_pending_futures = int(debug_info["num_pending_futures"])
        num_pending_users = int(debug_info["num_pending_users"])
        if num_pending_futures == 0 and num_pending_users == 0:
            break
        time.sleep(0.1)
        if time.time() - start > timeout:
            raise ValueError(
                f"Timed out waiting to flush pending futures and users, "
                f"had {num_pending_futures} pending futures and {num_pending_users} pending users"
            )


def get_num_owners_and_forks() -> Tuple[str, str]:
    """
    Retrieves number of OwnerRRefs and forks on this node from
    _rref_context_get_debug_info.
    """
    rref_dbg_info = _rref_context_get_debug_info()
    num_owners = rref_dbg_info["num_owner_rrefs"]
    num_forks = rref_dbg_info["num_forks"]
    return num_owners, num_forks


def wait_until_owners_and_forks_on_rank(
    num_owners: int, num_forks: int, rank: int, timeout: int = 20
) -> None:
    """
    Waits up to ``timeout`` seconds for exactly ``num_owners`` OwnerRRefs and
    ``num_forks`` forks to exist on the given rank. Used to ensure proper
    deletion of RRefs in tests.
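
    Example (sketch; ``dst_rank`` is a placeholder, and waiting for zero owners
    and zero forks checks that all RRefs on that rank have been cleaned up)::

        wait_until_owners_and_forks_on_rank(0, 0, rank=dst_rank)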
    """
    start = time.time()
    while True:
        num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync(
            worker_name(rank), get_num_owners_and_forks, args=(), timeout=5
        )
        num_owners_on_rank = int(num_owners_on_rank)
        num_forks_on_rank = int(num_forks_on_rank)
        if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks:
            return
        time.sleep(1)
        if time.time() - start > timeout:
            raise ValueError(
                f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank,"
                f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks"
            )


def initialize_pg(init_method, rank: int, world_size: int) -> None:
    # This is for tests using `dist.barrier`.
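    # Illustrative usage from a MultiProcessTestCase-derived test (a sketch;
    # ``self.file_name`` is assumed to be provided by the test fixture):
    #
    #   initialize_pg(
    #       INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
    #       self.rank,
    #       self.world_size,
    #   )
    #   dist.barrier()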
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size,
        )


def worker_name(rank: int) -> str:
    return f"worker{rank}"


def get_function_event(function_events, partial_event_name):
    """
    Returns the first event whose name contains partial_event_name in the
    provided function_events. These function_events should come from the
    autograd profiler (e.g. the ``function_events`` of a completed
    ``torch.autograd.profiler.profile`` run).

    Args:
        function_events: function_events returned by the profiler.
        partial_event_name (str): partial key that the event was profiled with.
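
    Example (sketch; ``my_func`` is a placeholder for whatever was profiled)::

        with torch.autograd.profiler.profile() as prof:
            my_func()
        event = get_function_event(prof.function_events, "my_func")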
    """
    event = [event for event in function_events if partial_event_name in event.name][0]  # noqa: RUF015
    return event