# Process-wide flags controlling how ``nn.Module`` conversions update parameters.
# Both start as ``False``; they are only changed through the setters below.
_overwrite_module_params_on_conversion: bool = False
_swap_module_params_on_conversion: bool = False


def set_overwrite_module_params_on_conversion(value: bool) -> None:
    """
    Sets whether to assign new tensors to the parameters instead of changing the
    existing parameters in-place when converting an ``nn.Module``.

    When enabled, the following methods will assign new parameters to the module:

    #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
    #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
    #. :meth:`nn.Module.to`
    #. :meth:`nn.Module.to_empty`

    .. note::
        When :func:`~torch.__future__.get_swap_module_params_on_conversion` returns ``True``,
        that setting takes precedence over this one.

    Args:
        value (bool): Whether to assign new tensors or not.

    """
    global _overwrite_module_params_on_conversion
    _overwrite_module_params_on_conversion = value


def get_overwrite_module_params_on_conversion() -> bool:
    """
    Returns whether to assign new tensors to the parameters instead of changing the
    existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.

    See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
    """
    return _overwrite_module_params_on_conversion


def set_swap_module_params_on_conversion(value: bool) -> None:
    """
    Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
    change the existing parameters in-place when converting an ``nn.Module`` and instead
    of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.

    .. note::
        When set to ``True``, this setting takes precedence over the one controlled by
        :func:`~torch.__future__.set_overwrite_module_params_on_conversion`
        (queried via :func:`~torch.__future__.get_overwrite_module_params_on_conversion`).

    When enabled, the following methods will swap the existing parameters in-place:

    #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
    #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
    #. :meth:`nn.Module.to`
    #. :meth:`nn.Module.to_empty`
    #. :meth:`nn.Module.load_state_dict`

    The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:

    #. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
       :meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
    #. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
    #. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
       with ``res``

    Args:
        value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not.

    """
    global _swap_module_params_on_conversion
    _swap_module_params_on_conversion = value


def get_swap_module_params_on_conversion() -> bool:
    """
    Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
    change the existing parameters in-place when converting an ``nn.Module`` and instead of
    ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
    Defaults to ``False``.

    See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
    """
    return _swap_module_params_on_conversion