# mypy: allow-untyped-defs
import copyreg
import functools
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional
from typing_extensions import ParamSpec

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)
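
# Illustrative usage of `_type` (a sketch, not part of the module's API
# surface): it backs the legacy `.type()` method on tensors and storages.
# With no `dtype` it returns the fully qualified type name; with a `dtype`
# it casts, copying only when the type actually changes.
#
#   >>> x = torch.ones(2)
#   >>> x.type()
#   'torch.FloatTensor'
#   >>> x.type(torch.DoubleTensor)
#   tensor([1., 1.], dtype=torch.float64)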


def _to(self, device, non_blocking=False):
    """Returns a copy of this object in device memory.

    If this object is already on the correct device, then no copy is performed
    and the original object is returned.

    Args:
        device (torch.device): The destination device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
    """
    if self.device == device:
        return self

    device_module = getattr(torch, device.type, None)
    assert (
        device_module is not None
    ), f"{device.type.upper()} device module is not loaded"
    with device_module.device(device):
        if self.is_sparse and hasattr(device_module, "sparse"):
            new_type = getattr(device_module.sparse, self.__class__.__name__)
            indices = getattr(torch.Tensor._indices(self), device.type)(
                device, non_blocking
            )
            values = getattr(torch.Tensor._values(self), device.type)(
                device, non_blocking
            )
            return new_type(indices, values, self.size())
        else:
            assert (
                not self.is_sparse
            ), f"sparse storage is not supported for {device.type.upper()} tensors"
            untyped_storage = torch.UntypedStorage(self.size(), device=device)
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]
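
# Illustrative behaviour of `_get_async_or_non_blocking` (a sketch): the
# deprecated ``async`` keyword is accepted as an alias for ``non_blocking``
# and triggers a warning, while any other stray keyword is rejected.
#
#   >>> _get_async_or_non_blocking("type", False, {"async": True})
#   True
#   >>> _get_async_or_non_blocking("type", False, {"asynchronous": True})
#   Traceback (most recent call last):
#       ...
#   TypeError: type() got an unexpected keyword argument 'asynchronous'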


def _get_restore_location(device):
    """Return the map_location location.

    Used for rebuild functions where the tensor device is distinct from the
    storage device.
    """

    map_location = torch.serialization._serialization_tls.map_location
    if map_location is None:
        return device
    else:
        if isinstance(map_location, dict):
            return map_location.get(device, device)
        elif isinstance(map_location, (str, torch.device)):
            return map_location
        else:
            assert callable(map_location)
            raise RuntimeError(
                "Callable map_location not supported with _rebuild_wrapper_subclass "
                "or _rebuild_device_tensor_from_numpy"
            )


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables.  This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
#   - It's fragile.  If you serialize a backward hook into a saved
#     model, and then you rename the function associated with the hook,
#     now your saved model is broken and you can't load it anymore.
#
#   - It's not actually used.  The standard recommendation is to
#     serialize the *state_dict* of a model, not the model itself
#     (since this is more stable to code changes affecting the model
#     serialization), and the state dict saves "data" only, thus
#     stripping the backward hooks.  In some cases, hooks are
#     essential to the well-functioning of a model (e.g., DDP),
#     but DDP already manages re-adding the hooks!
#
#   - We didn't serialize them in many cases.  Prior to #10220, we
#     were dropping backward hooks in ForkingPickler.  We "fixed" this
#     to be consistent with other serialization sites, but lack of
#     serializing backward hooks wasn't actually the root cause of
#     the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a hook
# has the property _torch_serialize_ignore, we will not emit a warning if we
# attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass None you'll run afoul of #12219.


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)


def get_tensor_metadata(tensor):
    # Tensor's metadata for serialization.
    # Currently, this only returns a dict[string, bool] specifying whether the
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    metadata=None,
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor
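
# Illustrative round trip for `_rebuild_tensor_v2` (a sketch of what the
# pickle rebuild does): given a TypedStorage, reconstruct a view with the
# requested shape and stride. The empty OrderedDict stands in for the
# no-longer-serialized hooks (see Note [Don't serialize hooks]).
#
#   >>> from collections import OrderedDict
#   >>> s = torch.arange(6.0).storage()
#   >>> _rebuild_tensor_v2(s, 0, (2, 3), (3, 1), False, OrderedDict())
#   tensor([[0., 1., 2.],
#           [3., 4., 5.]])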


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have already been unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because,
# due to Pickler semantics, we have to use the same (non-validating) function
# for unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )

    finally:
        _sparse_tensors_to_validate.clear()


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (torch.layout): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")
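
# Illustrative use of the deferred-validation protocol described above (a
# sketch of what the unpicklers do): rebuild without validation first, then
# validate every pending sparse tensor in bulk.
#
#   >>> i = torch.tensor([[0, 1], [1, 0]])
#   >>> v = torch.tensor([3.0, 4.0])
#   >>> t = _rebuild_sparse_tensor(torch.sparse_coo, (i, v, (2, 2)))
#   >>> _validate_loaded_sparse_tensors()  # raises if indices/values are inconsistent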


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    device = _get_restore_location(device)
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used; kept only to be able to load Tensors serialized with older versions of PyTorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls,
    dtype,
    size,
    stride,
    storage_offset,
    layout,
    device,
    requires_grad,
):
    device = _get_restore_location(device)
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        dtype=dtype,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore state on Parameter like python attr.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class, but since Tensor does
    # not inherit from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj
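
# Illustrative round trip for `_get_obj_state`/`_set_obj_state` (a sketch;
# `MyParam` and the `tag` attribute are made up for the example): custom
# Python-level attributes on a tensor subclass survive the round trip.
#
#   >>> class MyParam(torch.Tensor):
#   ...     pass
#   >>> src = torch.ones(2).as_subclass(MyParam)
#   >>> src.tag = "calibration"
#   >>> dst = torch.ones(2).as_subclass(MyParam)
#   >>> _set_obj_state(dst, _get_obj_state(src)).tag
#   'calibration'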


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    the same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operations on this buffer will be equivalent to
    operating on the tensors individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of the same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    the same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
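
# Illustrative round trip for `_flatten_dense_tensors`/`_unflatten_dense_tensors`
# (a sketch): tensors of the same type are packed into one flat buffer and
# later re-viewed with their original shapes.
#
#   >>> ts = [torch.ones(2), torch.zeros(3)]
#   >>> flat = _flatten_dense_tensors(ts)
#   >>> flat.shape
#   torch.Size([5])
#   >>> [u.shape for u in _unflatten_dense_tensors(flat, ts)]
#   [torch.Size([2]), torch.Size([3])]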


def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of the same sparse type, and that flat is
    given by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
          tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of the same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of the same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
          the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
          reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields a chunk at a time, each
    containing tensors of the same type, up to a certain byte limit in total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within its types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf
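
# Illustrative chunking with `_take_tensors` (a sketch): three float32 tensors
# of 512, 512 and 2048 bytes with a 1024-byte limit are grouped into two
# chunks; a single tensor above the limit still forms its own chunk.
#
#   >>> ts = [torch.ones(128), torch.ones(128), torch.ones(512)]
#   >>> [len(chunk) for chunk in _take_tensors(ts, 1024)]
#   [2, 1]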


# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holds a reference to the frame,
# and the frame (which holds references to all the objects in its temporary
# scope) holds a reference to the traceback.


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions take a non-str first argument but explicitly
            # have a message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception
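
# Illustrative usage of `ExceptionWrapper` (a sketch of the pattern DataLoader
# workers follow; the names `box` and `worker` are made up for the example):
# capture the exception where it happens, then re-raise it on the main thread
# with the original traceback embedded in the message.
#
#   >>> import threading
#   >>> box = []
#   >>> def worker():
#   ...     try:
#   ...         raise ValueError("bad sample")
#   ...     except Exception:
#   ...         box.append(ExceptionWrapper(where="in worker thread"))
#   >>> t = threading.Thread(target=worker); t.start(); t.join()
#   >>> box[0].reraise()  # raises ValueError: Caught ValueError in worker thread. ...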


def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    if hasattr(torch, "mtia") and torch.mtia.is_available():
        return "mtia"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type and device_type.lower() == "mtia":
        return get_member(torch.mtia)
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device indices
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any,
    optional: bool = False,
    allow_cpu: bool = False,
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has an index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    i.e., the current default CUDA device will be returned if the CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non-cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got: {device}"
            )
    return device_idx
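
# Illustrative results for `_get_device_index` (a sketch; none of these calls
# require a GPU to be present):
#
#   >>> _get_device_index(torch.device("cuda", 1))
#   1
#   >>> _get_device_index("cpu", allow_cpu=True)
#   -1
#   >>> _get_device_index(3)
#   3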


def _handle_complex(tensor):
    """
    Returns a real view of a tensor if it has a complex dtype, else just the tensor.
    We need to check for UninitializedParameter because otherwise checking
    is_complex is an error for a LazyModule.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )


def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3
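
# Illustrative values for `_element_size` (a sketch): finfo/iinfo report the
# bit width of a single real component, so complex dtypes shift by 2 (two
# components per element) while real dtypes shift by 3.
#
#   >>> _element_size(torch.float32), _element_size(torch.complex64), _element_size(torch.bool)
#   (4, 8, 1)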


class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)
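
# Illustrative usage of `classproperty` (a sketch; `Config` and `device_hint`
# are made-up names): a read-only property computed on the class rather than
# on an instance.
#
#   >>> class Config:
#   ...     @classproperty
#   ...     def device_hint(cls):
#   ...         return "cpu"
#   >>> Config.device_hint
#   'cpu'
#   >>> Config().device_hint
#   'cpu'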


def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()


def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate the updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module


def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )
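
# Illustrative behaviour of `_dummy_type` (a sketch; `_FakeCudaStream` is a
# made-up name): dummy types stand in for backend classes that are unavailable
# in this build and fail loudly if anyone tries to instantiate them.
#
#   >>> _FakeCudaStream = _dummy_type("_FakeCudaStream")
#   >>> _FakeCudaStream()
#   Traceback (most recent call last):
#       ...
#   RuntimeError: Tried to instantiate dummy base class _FakeCudaStream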


class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on the current device. We track the order of the **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order


logger = logging.getLogger(__name__)
P = ParamSpec("P")


class CallbackRegistry(Generic[P]):
    def __init__(self, name: str):
        self.name = name
        self.callback_list: List[Callable[P, None]] = []

    def add_callback(self, cb: Callable[P, None]) -> None:
        self.callback_list.append(cb)

    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
        for cb in self.callback_list:
            try:
                cb(*args, **kwargs)
            except Exception as e:
                logger.exception(
                    "Exception in callback for %s registered with gpu trace", self.name
                )
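
# Illustrative usage of `CallbackRegistry` (a sketch; the registry name
# "example" is made up): callbacks fire in registration order, and exceptions
# raised by a callback are logged rather than propagated.
#
#   >>> registry = CallbackRegistry("example")
#   >>> registry.add_callback(lambda x: print(x + 1))
#   >>> registry.fire_callbacks(41)
#   42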


# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

IMPORT_MAPPING = {
    "__builtin__": "builtins",
    "copy_reg": "copyreg",
    "Queue": "queue",
    "repr": "reprlib",
    "_abcoll": "collections.abc",
    # Non-mutual mappings.
    "UserDict": "collections",
    "UserList": "collections",
    "UserString": "collections",
    "whichdb": "dbm",
    "StringIO": "io",
    "cStringIO": "io",
}


# This contains rename rules that are easy to handle.  We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
NAME_MAPPING = {
    ("__builtin__", "xrange"): ("builtins", "range"),
    ("__builtin__", "reduce"): ("functools", "reduce"),
    ("__builtin__", "intern"): ("sys", "intern"),
    ("__builtin__", "unichr"): ("builtins", "chr"),
    ("__builtin__", "unicode"): ("builtins", "str"),
    ("__builtin__", "long"): ("builtins", "int"),
    ("itertools", "izip"): ("builtins", "zip"),
    ("itertools", "imap"): ("builtins", "map"),
    ("itertools", "ifilter"): ("builtins", "filter"),
    ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
    ("itertools", "izip_longest"): ("itertools", "zip_longest"),
    ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
    ("UserList", "UserList"): ("collections", "UserList"),
    ("UserString", "UserString"): ("collections", "UserString"),
    # Non-mutual mappings.
    ("__builtin__", "basestring"): ("builtins", "str"),
    ("exceptions", "StandardError"): ("builtins", "Exception"),
    ("UserDict", "UserDict"): ("collections", "UserDict"),
}