# mypy: allow-untyped-defs
import copyreg
import functools
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional
from typing_extensions import ParamSpec

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)


def _to(self, device, non_blocking=False):
    """Returns a copy of this object in device memory.

    If this object is already on the correct device, then no copy is performed
    and the original object is returned.

    Args:
        device (torch.device): The destination device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
    """
    if self.device == device:
        return self

    device_module = getattr(torch, device.type, None)
    assert (
        device_module is not None
    ), f"{device.type.upper()} device module is not loaded"
    with device_module.device(device):
        if self.is_sparse and hasattr(device_module, "sparse"):
            new_type = getattr(device_module.sparse, self.__class__.__name__)
            indices = getattr(torch.Tensor._indices(self), device.type)(
                device, non_blocking
            )
            values = getattr(torch.Tensor._values(self), device.type)(
                device, non_blocking
            )
            return new_type(indices, values, self.size())
        else:
            assert (
                not self.is_sparse
            ), f"sparse storage is not supported for {device.type.upper()} tensors"
            untyped_storage = torch.UntypedStorage(self.size(), device=device)
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        **kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]
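# A minimal illustration (not part of the public API) of how the helper above
# translates the deprecated ``async`` keyword into ``non_blocking``::
#
#     >>> _get_async_or_non_blocking("type", False, {"async": True})
#     True   # also emits a warning that 'async' is deprecated
#     >>> _get_async_or_non_blocking("type", False, {"bogus": True})
#     TypeError: type() got an unexpected keyword argument 'bogus'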


def _get_restore_location(device):
    """Return the map_location location.

    Used for rebuild functions where the tensor device is distinct from the storage
    """

    map_location = torch.serialization._serialization_tls.map_location
    if map_location is None:
        return device
    else:
        if isinstance(map_location, dict):
            return map_location.get(device, device)
        elif isinstance(map_location, (str, torch.device)):
            return map_location
        else:
            assert callable(map_location)
            raise RuntimeError(
                "Callable map_location not supported with _rebuild_wrapper_subclass "
                "or _rebuild_device_tensor_from_numpy"
            )


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables. This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
# - It's fragile. If you serialize a backward hook into a saved
#   model, and then you rename the function associated with the hook,
#   now your saved model is broken and you can't load it anymore.
#
# - It's not actually used. The standard recommendation is to
#   serialize the *state_dict* of a model, not the model itself
#   (since this is more stable to code changes affecting the model
#   serialization), and the state dict saves "data" only, thus
#   stripping the backward hooks. In some cases, hooks are
#   essential to the well-functioning of a model (e.g., DDP),
#   but DDP already manages re-adding the hooks!
#
# - We didn't serialize them in many cases. Prior to #10220, we
#   were dropping backward hooks in ForkingPickler. We "fixed" this
#   to be consistent with other serialization sites, but lack of
#   serializing backward hooks wasn't actually the root cause of
#   the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property on hooks that will be used to suppress this warning. If a hook
# has the property _torch_serialize_ignore, we will not emit a warning if we
# attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass a None you'll run afoul of #12219.


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
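# A minimal sketch of what the pickling machinery does with the helper above;
# as noted in the TODO, ``storage`` is a TypedStorage (illustrative only)::
#
#     >>> base = torch.arange(6, dtype=torch.float32)
#     >>> t = _rebuild_tensor(base.storage(), storage_offset=0, size=(2, 3), stride=(3, 1))
#     >>> t.shape
#     torch.Size([2, 3])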


def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    metadata=None,
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have already been unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because due
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )

    finally:
        _sparse_tensors_to_validate.clear()
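# A sketch of the deferred-validation flow described above: sparse tensors are
# built without invariant checks, queued, and validated in bulk at the end of
# loading (illustrative only)::
#
#     >>> i = torch.tensor([[0, 1], [1, 0]])
#     >>> v = torch.tensor([3.0, 4.0])
#     >>> t = torch.sparse_coo_tensor(i, v, (2, 2), check_invariants=False)
#     >>> _sparse_tensors_to_validate.append(t)
#     >>> _validate_loaded_sparse_tensors()  # validates and clears the pending list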


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (torch.layout): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    device = _get_restore_location(device)
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls,
    dtype,
    size,
    stride,
    storage_offset,
    layout,
    device,
    requires_grad,
):
    device = _get_restore_location(device)
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        dtype=dtype,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor
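# A small illustration (not part of the public API) of the storage-less meta
# rebuild defined above, ``_rebuild_meta_tensor_no_storage``::
#
#     >>> t = _rebuild_meta_tensor_no_storage(torch.float32, (2, 3), (3, 1), False)
#     >>> t.device.type, tuple(t.shape)
#     ('meta', (2, 3))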


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore state on Parameter like python attr.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class but since Tensor does not
    # inherit from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)
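# A small illustration of ``_import_dotted_name`` (defined a few functions
# above; it is what ``_type`` uses when a dtype is given as a string)::
#
#     >>> _import_dotted_name("torch.nn.functional") is torch.nn.functional
#     True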


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)


def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
            tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
            the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
            reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)
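# Round-tripping through the dense flatten/unflatten helpers above
# (illustrative sketch only)::
#
#     >>> xs = [torch.zeros(2, 2), torch.ones(3)]
#     >>> flat = _flatten_dense_tensors(xs)          # shape: torch.Size([7])
#     >>> ys = _unflatten_dense_tensors(flat, xs)
#     >>> [tuple(y.shape) for y in ys]
#     [(2, 2), (3,)]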


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields one chunk at a time,
    each containing tensors of the same type, up to a certain byte limit in
    total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within their types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf
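# A small worked example (illustrative only) of the chunking behaviour above:
# three float32 tensors of 512 bytes each, grouped under a 1 KiB limit::
#
#     >>> ts = [torch.zeros(128), torch.zeros(128), torch.zeros(128)]
#     >>> [len(chunk) for chunk in _take_tensors(ts, 1024)]
#     [2, 1]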


# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the objects in its temporary
# scope) holding reference to the traceback.


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception
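# A minimal sketch of the intended use of ExceptionWrapper: capture an
# exception raised on a worker thread/process and re-raise it on the main
# thread (illustrative only)::
#
#     >>> try:
#     ...     raise ValueError("bad sample")
#     ... except Exception:
#     ...     wrapper = ExceptionWrapper(where="in DataLoader worker process 0")
#     >>> wrapper.reraise()
#     # raises ValueError whose message embeds the original traceback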


def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    if hasattr(torch, "mtia") and torch.mtia.is_available():
        return "mtia"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type and device_type.lower() == "mtia":
        return get_member(torch.mtia)
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device indices
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any,
    optional: bool = False,
    allow_cpu: bool = False,
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has an index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    i.e., the current default CUDA device will be returned if the CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got: {device}"
            )
    return device_idx


def _handle_complex(tensor):
    """
    Returns a real view of a tensor if it has a complex dtype, else the tensor
    unchanged. We need to check for UninitializedParameter first because calling
    is_complex on it is an error for a LazyModule.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )


def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3


class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)
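# A small usage sketch (illustrative only) of the class-level property helper
# defined above::
#
#     >>> class Config:
#     ...     @classproperty
#     ...     def name(cls):
#     ...         return cls.__name__
#     >>> Config.name
#     'Config'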


def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()


def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate the updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module


def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )


class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on the current device. We track the order of the **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order
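# A quick illustration of ``_dummy_type`` above, which produces a placeholder
# class that refuses instantiation (the class name below is made up)::
#
#     >>> _FakeStreamBase = _dummy_type("_FakeStreamBase")
#     >>> _FakeStreamBase()
#     RuntimeError: Tried to instantiate dummy base class _FakeStreamBase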


logger = logging.getLogger(__name__)
P = ParamSpec("P")


class CallbackRegistry(Generic[P]):
    def __init__(self, name: str):
        self.name = name
        self.callback_list: List[Callable[P, None]] = []

    def add_callback(self, cb: Callable[P, None]) -> None:
        self.callback_list.append(cb)

    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
        for cb in self.callback_list:
            try:
                cb(*args, **kwargs)
            except Exception as e:
                logger.exception(
                    "Exception in callback for %s registered with gpu trace", self.name
                )


# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

IMPORT_MAPPING = {
    "__builtin__": "builtins",
    "copy_reg": "copyreg",
    "Queue": "queue",
    "repr": "reprlib",
    "_abcoll": "collections.abc",
    # Non-mutual mappings.
    "UserDict": "collections",
    "UserList": "collections",
    "UserString": "collections",
    "whichdb": "dbm",
    "StringIO": "io",
    "cStringIO": "io",
}


# This contains rename rules that are easy to handle. We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
NAME_MAPPING = {
    ("__builtin__", "xrange"): ("builtins", "range"),
    ("__builtin__", "reduce"): ("functools", "reduce"),
    ("__builtin__", "intern"): ("sys", "intern"),
    ("__builtin__", "unichr"): ("builtins", "chr"),
    ("__builtin__", "unicode"): ("builtins", "str"),
    ("__builtin__", "long"): ("builtins", "int"),
    ("itertools", "izip"): ("builtins", "zip"),
    ("itertools", "imap"): ("builtins", "map"),
    ("itertools", "ifilter"): ("builtins", "filter"),
    ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
    ("itertools", "izip_longest"): ("itertools", "zip_longest"),
    ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
    ("UserList", "UserList"): ("collections", "UserList"),
    ("UserString", "UserString"): ("collections", "UserString"),
    # Non-mutual mappings.
    ("__builtin__", "basestring"): ("builtins", "str"),
    ("exceptions", "StandardError"): ("builtins", "Exception"),
    ("UserDict", "UserDict"): ("collections", "UserDict"),
}
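# A short usage sketch (illustrative only) of the CallbackRegistry defined
# above, in the spirit of the gpu-trace hooks its log message mentions::
#
#     >>> on_alloc: CallbackRegistry[int] = CallbackRegistry("memory allocation")
#     >>> on_alloc.add_callback(lambda nbytes: print(f"allocated {nbytes} bytes"))
#     >>> on_alloc.fire_callbacks(1024)
#     allocated 1024 bytes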