# mypy: ignore-errors

import re
import sys
import time
from functools import partial, wraps
from typing import Tuple

import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import _rref_context_get_debug_info
from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN


if not dist.is_available():
    print("c10d not available, skipping tests", file=sys.stderr)
    sys.exit(0)


INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}"


def dist_init(
    old_test_method=None,
    setup_rpc: bool = True,
    clean_shutdown: bool = True,
    faulty_messages=None,
    messages_to_delay=None,
):
    """
    We use this decorator to set up and tear down state since
    MultiProcessTestCase runs each `test*` method in a separate process and
    each process just runs the `test*` method without calling unittest's
    'setUp' and 'tearDown' methods.

    Note: pass the string representation of the MessageTypes that should be
    used with the faulty agent's send function. By default, all retriable
    messages ("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE",
    "CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is
    set from faulty_rpc_agent_test_fixture.py).
    """
    # If dist_init is used without arguments (e.g. @dist_init), old_test_method
    # is already set and we return the wrapped test method directly. If
    # dist_init is used with arguments (e.g. @dist_init(clean_shutdown=False)),
    # old_test_method is None and we return a functools.partial, which acts as
    # the real decorator: when it is applied, dist_init is called again with
    # old_test_method and the remaining arguments set.
    if old_test_method is None:
        return partial(
            dist_init,
            setup_rpc=setup_rpc,
            clean_shutdown=clean_shutdown,
            faulty_messages=faulty_messages,
            messages_to_delay=messages_to_delay,
        )

    @wraps(old_test_method)
    def new_test_method(self, *arg, **kwargs):
        # Set _ignore_rref_leak to False to make sure OwnerRRefs are properly
        # deleted in tests.
        import torch.distributed.rpc.api as api

        api._ignore_rref_leak = False
        self.worker_id = self.rank
        self.setup_fault_injection(faulty_messages, messages_to_delay)

        rpc_backend_options = self.rpc_backend_options
        if setup_rpc:
            if TEST_WITH_TSAN:
                # TSAN runs much slower.
                rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5
                rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60

            rpc.init_rpc(
                name="worker%d" % self.rank,
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=rpc_backend_options,
            )

        return_value = old_test_method(self, *arg, **kwargs)

        if setup_rpc:
            rpc.shutdown(graceful=clean_shutdown)

        return return_value

    return new_test_method


def noop() -> None:
    pass


def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
    """
    Loops until an RPC to the given rank fails. This is used to indicate that
    the node has failed in unit tests.

    Args:
        rank (int): Rank of the node expected to fail.
        expected_error_regex (optional, str): Regex of the expected exception
            message. Useful to ensure a specific failure occurs, not just any.
    """
    while True:
        try:
            rpc.rpc_sync(f"worker{rank}", noop, args=())
            time.sleep(0.1)
        except Exception as e:
            if re.search(pattern=expected_error_regex, string=str(e)):
                return str(e)


def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
    """
    The RRef protocol holds forkIds of rrefs in a map until those forks are
    confirmed by the owner. The message confirming the fork may arrive after
    our tests check whether this map is empty, which leads to failures and
    flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    and processed. Call this function before asserting that the map returned
    by _get_debug_info is empty.
    """
    start = time.time()
    while True:
        debug_info = _rref_context_get_debug_info()
        num_pending_futures = int(debug_info["num_pending_futures"])
        num_pending_users = int(debug_info["num_pending_users"])
        if num_pending_futures == 0 and num_pending_users == 0:
            break
        time.sleep(0.1)
        if time.time() - start > timeout:
            raise ValueError(
                f"Timed out waiting to flush pending futures and users, "
                f"had {num_pending_futures} pending futures and {num_pending_users} pending users"
            )


def get_num_owners_and_forks() -> Tuple[str, str]:
    """
    Retrieves the number of OwnerRRefs and forks on this node from
    _rref_context_get_debug_info.
    """
    rref_dbg_info = _rref_context_get_debug_info()
    num_owners = rref_dbg_info["num_owner_rrefs"]
    num_forks = rref_dbg_info["num_forks"]
    return num_owners, num_forks


def wait_until_owners_and_forks_on_rank(
    num_owners: int, num_forks: int, rank: int, timeout: int = 20
) -> None:
    """
    Waits up to timeout seconds for num_owners owners and num_forks forks to
    exist on the given rank. Used to ensure proper deletion of RRefs in tests.
    """
    start = time.time()
    while True:
        num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync(
            worker_name(rank), get_num_owners_and_forks, args=(), timeout=5
        )
        num_owners_on_rank = int(num_owners_on_rank)
        num_forks_on_rank = int(num_forks_on_rank)
        if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks:
            return
        time.sleep(1)
        if time.time() - start > timeout:
            raise ValueError(
                f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank,"
                f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks"
            )


def initialize_pg(init_method, rank: int, world_size: int) -> None:
    # This is for tests that use `dist.barrier`.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size,
        )


def worker_name(rank: int) -> str:
    return f"worker{rank}"


def get_function_event(function_events, partial_event_name):
    """
    Returns the first event that matches partial_event_name in the provided
    function_events. These function_events should be the events produced by the
    autograd profiler (i.e. a profile's function_events).

    Args:
        function_events: function_events returned by the profiler.
        partial_event_name (str): partial key that the event was profiled with.
    """
    event = [event for event in function_events if partial_event_name in event.name][0]  # noqa: RUF015
    return event