1# mypy: allow-untyped-defs 2import cProfile 3import inspect 4import io 5import itertools 6import os 7import warnings 8from contextlib import contextmanager 9from functools import wraps 10from pstats import Stats 11from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union 12 13import torch 14import torch.distributed as dist 15from torch.distributed._shard.sharded_tensor import ShardedTensor 16from torch.distributed._shard.sharded_tensor.shard import Shard 17 18from .api import ( 19 _is_wrapped_exception, 20 _wrap_exception, 21 CheckpointException, 22 WRAPPED_EXCEPTION, 23) 24from .metadata import MetadataIndex, STATE_DICT_TYPE 25 26 27__all__ = ["find_tensor_shard", "find_state_dict_object"] 28 29T = TypeVar("T") 30R = TypeVar("R") 31 32 33def _get_failure_dict( 34 results: List[Union[T, WRAPPED_EXCEPTION]] 35) -> Dict[int, WRAPPED_EXCEPTION]: 36 return cast( 37 Dict[int, WRAPPED_EXCEPTION], 38 {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}, 39 ) 40 41 42def _all_gather_keys( 43 local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None 44) -> List[Any]: 45 """Gathers all keys, and returns them sorted.""" 46 keys = list(local_dict.keys()) 47 gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item] 48 49 dist.all_gather_object(gathered_keys, keys, group=group) 50 return sorted(set(itertools.chain.from_iterable(gathered_keys))) 51 52 53class _DistWrapper: 54 """ 55 This is a wrapper around PG that provides a series of features around object collectives. 56 57 It works without distributed initialized, where most collectives turns into nops. 58 59 All variants that take functions are exception robust, meaning that if one or more 60 ranks raise errors, all ranks will observe those. 61 """ 62 63 def __init__( 64 self, 65 group: Optional[dist.ProcessGroup], 66 use_dist: bool, 67 coordinator_rank: int, 68 ): 69 self.group = group 70 self.use_dist = use_dist 71 self.coordinator_rank = coordinator_rank 72 if self.use_dist: 73 self.rank = dist.get_rank(group) 74 self.is_coordinator = self.rank == coordinator_rank 75 else: 76 self.rank = 0 77 self.is_coordinator = True 78 79 def get_rank(self) -> int: 80 return self.rank 81 82 def get_world_size(self) -> int: 83 if self.use_dist: 84 return dist.get_world_size(self.group) 85 return 1 86 87 def broadcast_object(self, object: Optional[T]) -> T: 88 """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled.""" 89 object_list = [object] 90 if self.use_dist: 91 dist.broadcast_object_list( 92 object_list=object_list, 93 group=self.group, 94 src=self.coordinator_rank, 95 ) 96 return cast(T, object_list[0]) 97 98 def gather_object(self, object: T) -> Optional[List[T]]: 99 """Implement functionality similar to c10d::gather_object but without distributed enabled.""" 100 if self.use_dist: 101 gather_objs = ( 102 cast(List[T], [None] * dist.get_world_size(self.group)) 103 if self.is_coordinator 104 else None 105 ) 106 107 dist.gather_object( 108 obj=object, 109 object_gather_list=gather_objs if self.is_coordinator else None, 110 dst=self.coordinator_rank, 111 group=self.group, 112 ) 113 result = gather_objs 114 else: 115 result = [object] 116 return result 117 118 def all_gather_object(self, object: T) -> List[T]: 119 """Implement functionality similar to c10d::all_gather_object but without distributed enabled.""" 120 if self.use_dist: 121 gather_objs = cast(List[T], [None] * dist.get_world_size(self.group)) 122 123 dist.all_gather_object( 124 object_list=gather_objs, obj=object, group=self.group 125 ) 126 else: 127 gather_objs = [object] 128 return gather_objs 129 130 def scatter_object(self, object_list: Optional[List[T]]) -> T: 131 """Implement functionality similar to c10d::scatter_object but without distributed enabled.""" 132 if self.use_dist: 133 gather_result = cast(List[T], [None]) 134 dist.scatter_object_list( 135 scatter_object_output_list=gather_result, 136 scatter_object_input_list=object_list if self.is_coordinator else None, 137 src=self.coordinator_rank, 138 group=self.group, 139 ) 140 141 local_reply = gather_result[0] 142 else: 143 assert object_list is not None 144 local_reply = object_list[0] 145 return local_reply 146 147 def reduce_scatter( 148 self, 149 step: str, 150 map_fun: Callable[[], T], 151 reduce_fun: Callable[[List[T]], List[R]], 152 ) -> R: 153 """ 154 Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. 155 156 This method operates in the following way: 157 Run ``map_fun`` on all ranks 158 Gather results on rank 0 159 Call ``reduce_fun`` on all those values 160 Scatter to each rank part of the result. 161 """ 162 local_data: Union[WRAPPED_EXCEPTION, T] 163 try: 164 local_data = map_fun() 165 except BaseException as e: 166 local_data = _wrap_exception(e) 167 168 all_data = self.gather_object(local_data) 169 all_results: Optional[List[Union[R, CheckpointException]]] = None 170 if self.is_coordinator: 171 assert all_data is not None 172 node_failures = _get_failure_dict(all_data) 173 174 if len(node_failures) == 0: 175 try: 176 # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? 177 all_results = cast( 178 List[Union[R, CheckpointException]], 179 reduce_fun(cast(List[T], all_data)), 180 ) 181 except BaseException as e: 182 node_failures[self.rank] = _wrap_exception(e) 183 184 if len(node_failures) > 0: 185 all_results = [ 186 CheckpointException(step, node_failures) 187 ] * self.get_world_size() 188 189 result = self.scatter_object(all_results) 190 if isinstance(result, CheckpointException): 191 raise result 192 return result 193 194 def all_reduce( 195 self, 196 step: str, 197 map_fun: Callable[[], T], 198 reduce_fun: Callable[[List[T]], R], 199 ) -> R: 200 """ 201 Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. 202 203 This method operates in the following way: 204 Run ``map_fun`` on all ranks 205 Gather results on rank 0 206 Call ``reduce_fun`` on all those values 207 Broadcast the reduced value to all ranks. 208 """ 209 local_data: Union[T, WRAPPED_EXCEPTION] 210 try: 211 local_data = map_fun() 212 except BaseException as e: 213 local_data = _wrap_exception(e) 214 215 all_data = self.gather_object(local_data) 216 result: Optional[Union[R, CheckpointException]] = None 217 if self.is_coordinator: 218 assert all_data is not None 219 node_failures = _get_failure_dict(all_data) 220 if len(node_failures) == 0: 221 try: 222 result = reduce_fun(cast(List[T], all_data)) 223 except BaseException as e: 224 node_failures[self.rank] = _wrap_exception(e) 225 226 if len(node_failures) > 0: 227 result = CheckpointException(step, node_failures) 228 229 final_result = self.broadcast_object(result) 230 if isinstance(final_result, CheckpointException): 231 raise final_result 232 return cast(R, final_result) 233 234 def all_gather( 235 self, 236 step: str, 237 map_fun: Callable[[], T], 238 ) -> List[T]: 239 """ 240 Compute a value on each rank, then all_gather them. 241 242 This method operates in the following way: 243 Run ``map_cp`` on all ranks 244 all_gather the values to all ranks 245 """ 246 result: Union[T, WRAPPED_EXCEPTION] 247 try: 248 result = map_fun() 249 except BaseException as e: 250 result = _wrap_exception(e) 251 252 all_results = self.all_gather_object(result) 253 254 node_failures = _get_failure_dict(all_results) 255 if len(node_failures) > 0: 256 raise CheckpointException(step, node_failures) 257 return cast(List[T], all_results) 258 259 def broadcast( 260 self, 261 step: str, 262 map_fun: Callable[[], T], 263 ) -> T: 264 """ 265 Compute a value on rank 0 and broadcast it. 266 267 This method operates in the following way: 268 Run ``map_cp`` on rank 0 269 broadcast the value 270 """ 271 result: Optional[Union[T, CheckpointException]] = None 272 if self.is_coordinator: 273 try: 274 result = map_fun() 275 except BaseException as e: 276 result = CheckpointException(step, {self.rank: _wrap_exception(e)}) 277 final_result = self.broadcast_object(result) 278 if isinstance(final_result, CheckpointException): 279 raise final_result 280 return cast(T, final_result) 281 282 283def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: 284 if index.offset is None: 285 raise ValueError( 286 f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided" 287 ) 288 289 shards = tensor.local_shards() 290 # index fast path 291 if index.index is not None: 292 if ( 293 len(shards) > index.index 294 and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset 295 ): 296 return shards[index.index] 297 298 for shard in shards: 299 if torch.Size(shard.metadata.shard_offsets) == index.offset: 300 return shard 301 raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") 302 303 304def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: 305 if hasattr(tensor, "__get_tensor_shard__"): 306 # DTensor implements _Checkpointable 307 return tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] 308 if isinstance(tensor, ShardedTensor): 309 return _find_shard(tensor, index).tensor 310 if index.offset is not None: 311 # special case looking up a tensor by origin 312 if index.offset == torch.Size([0] * len(tensor.size())): 313 return tensor 314 raise ValueError( 315 f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" 316 ) 317 return tensor 318 319 320def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: 321 if index.fqn not in state_dict: 322 raise ValueError(f"Could not find FQN: '{index.fqn}'") 323 obj = state_dict[index.fqn] 324 325 if isinstance(obj, torch.Tensor): 326 return find_tensor_shard(obj, index) 327 elif index.offset is not None: 328 raise ValueError( 329 f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" 330 ) 331 return obj 332 333 334def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]: 335 return [i_a + i_b for i_a, i_b in zip(a, b)] 336 337 338def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]: 339 return [i_a - i_b for i_a, i_b in zip(a, b)] 340 341 342class _ReaderView(io.IOBase): 343 def __init__(self, base_stream: io.IOBase, offset: int, len: int): 344 super().__init__() 345 self.offset = offset 346 self.len = len 347 self.base_stream = base_stream 348 self.seek(0) 349 350 def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int: 351 if __whence == os.SEEK_SET: 352 __offset = self.offset + __offset 353 elif __whence == os.SEEK_END: 354 __whence = os.SEEK_SET 355 __offset = (self.offset + self.len) - __offset 356 return self.base_stream.seek(__offset, __whence) 357 358 def tell(self) -> int: 359 return self.base_stream.tell() - self.offset 360 361 def readable(self) -> bool: 362 return self.base_stream.readable() 363 364 def seekable(self) -> bool: 365 return self.base_stream.seekable() 366 367 def readinto(self, b): 368 return self.base_stream.readinto(b) # type: ignore[attr-defined] 369 370 def read(self, size=-1): 371 return self.base_stream.read(size) 372 373 374def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase: 375 # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader 376 return _ReaderView(file, offset, length) 377 378 379def _normalize_device_info(device_type: str, device_id: int) -> str: 380 """Device info normalization.""" 381 if device_type == "cpu": 382 return "cpu" 383 return f"{device_type}:{device_id}" 384 385 386# TODO: integrate with distributed logging flag 387ENABLE_PROFILE = False 388 389 390@contextmanager 391def _profile(): 392 # Only log the profiling when it is enable and is on rank0 or dist is not 393 # avaiable. 394 if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0): 395 profiler = cProfile.Profile() 396 profiler.enable() 397 try: 398 yield 399 finally: 400 profiler.disable() 401 stats = Stats(profiler) 402 stats.sort_stats("time").print_stats(10) 403 else: 404 yield 405 406 407def _api_bc_check(func): 408 @wraps(func) 409 def inner_func(*args, **kwargs) -> Any: 410 if len(args) == 2: 411 warnings.warn( 412 f"The argument order of {func.__name__} has been changed. " 413 "Please check the document to avoid future breakages." 414 ) 415 sig = inspect.signature(func) 416 kwonlyargs = [ 417 p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY 418 ] 419 if "storage_writer" in kwonlyargs: 420 assert "storage_writer" not in kwargs, (args, kwargs) 421 kwargs["storage_writer"] = args[1] 422 elif "storage_reader" in kwonlyargs: 423 assert "storage_reader" not in kwargs, (args, kwargs) 424 kwargs["storage_reader"] = args[1] 425 else: 426 raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}") 427 return func(args[0], **kwargs) 428 else: 429 return func(*args, **kwargs) 430 431 return inner_func 432