# xref: /aosp_15_r20/external/pytorch/torch/__future__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Module-level feature flags. They are written by the ``set_*`` functions and
# read by the ``get_*`` functions in this module. Both default to ``False``.

# Whether converting an ``nn.Module`` assigns new tensors to its parameters
# instead of changing the existing parameters in-place
# (see ``set_overwrite_module_params_on_conversion``).
_overwrite_module_params_on_conversion: bool = False
# Whether ``torch.utils.swap_tensors`` is used instead of ``.data`` assignment /
# ``param.copy_`` (see ``set_swap_module_params_on_conversion``). When enabled,
# this flag takes precedence over ``_overwrite_module_params_on_conversion``.
_swap_module_params_on_conversion: bool = False


def set_overwrite_module_params_on_conversion(value: bool) -> None:
    """
    Sets whether to assign new tensors to the parameters instead of changing the
    existing parameters in-place when converting an ``nn.Module``.

    When enabled, the following methods will assign new parameters to the module:

    #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
    #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
    #. :meth:`nn.Module.to`
    #. :meth:`nn.Module.to_empty`

    .. note::
        The setting controlled by :func:`~torch.__future__.set_swap_module_params_on_conversion`
        takes precedence over this one.

    Args:
        value (bool): Whether to assign new tensors or not.

    """
    # Rebind the module-level flag; it is read back by
    # get_overwrite_module_params_on_conversion.
    global _overwrite_module_params_on_conversion
    _overwrite_module_params_on_conversion = value


def get_overwrite_module_params_on_conversion() -> bool:
    """
    Returns whether to assign new tensors to the parameters instead of changing the
    existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``.

    The setting reported by :func:`~torch.__future__.get_swap_module_params_on_conversion`
    takes precedence over this one.

    See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
    """
    return _overwrite_module_params_on_conversion


def set_swap_module_params_on_conversion(value: bool) -> None:
    """
    Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
    change the existing parameters in-place when converting an ``nn.Module`` and instead
    of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.

    .. note::
        This setting takes precedence over the one controlled by
        :func:`~torch.__future__.set_overwrite_module_params_on_conversion`
        (and reported by :func:`~torch.__future__.get_overwrite_module_params_on_conversion`).

    When enabled, the following methods will swap the existing parameters in-place:

    #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
    #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
    #. :meth:`nn.Module.to`
    #. :meth:`nn.Module.to_empty`
    #. :meth:`nn.Module.load_state_dict`

    The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:

    #. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
       :meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
    #. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
    #. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
       with ``res``

    Args:
        value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not.

    """
    # Rebind the module-level flag; it is read back by
    # get_swap_module_params_on_conversion.
    global _swap_module_params_on_conversion
    _swap_module_params_on_conversion = value


def get_swap_module_params_on_conversion() -> bool:
    """
    Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
    change the existing parameters in-place when converting an ``nn.Module``, and instead of
    ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
    Defaults to ``False``.

    See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
    """
    return _swap_module_params_on_conversion