# mypy: allow-untyped-defs
import contextlib
import warnings
from typing import Generator

import torch
from torch._C import default_generator


def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note:: This function only works for CPU. For CUDA, please use
        :func:`torch.manual_seed`, which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)


def get_rng_state() -> torch.Tensor:
    r"""Returns the random number generator state as a `torch.ByteTensor`.

    .. note:: The returned state is for the default generator on CPU only.

    See also: :func:`torch.random.fork_rng`.
    """
    return default_generator.get_state()


def _seed_builtin_devices(seed: int) -> None:
    r"""Seeds the CUDA, MPS and XPU generators with ``seed``.

    A backend is skipped when its ``_is_in_bad_fork()`` check returns True.
    Shared by :func:`manual_seed` and :func:`seed` so that both seed the same
    backends in the same order.

    Args:
        seed (int): The seed to forward to every backend.
    """
    # Backend packages are imported at call time rather than at module import.
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)


def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers on all devices. Returns a
    `torch.Generator` object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
            is raised. Negative inputs are reinterpreted as unsigned 64-bit values,
            i.e. remapped with the formula `0x1_0000_0000_0000_0000 + seed`
            (so ``-1`` becomes ``0xffff_ffff_ffff_ffff``).
    """
    seed = int(seed)
    _seed_builtin_devices(seed)
    # Called directly from here (not from the helper) so that the warning's
    # stacklevel inside _seed_custom_device still points at our caller.
    _seed_custom_device(seed)

    # The CPU default generator is seeded last; its Generator is returned.
    return default_generator.manual_seed(seed)


def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number on all devices. Returns a 64 bit number used to seed the RNG.
    """
    # Draw the new seed from the CPU default generator, then propagate it.
    new_seed = default_generator.seed()
    _seed_builtin_devices(new_seed)
    # Called directly from here so the warning's stacklevel inside
    # _seed_custom_device still points at our caller.
    _seed_custom_device(new_seed)

    return new_seed


def _seed_custom_device(seed) -> None:
    r"""Sets the seed to generate random numbers for custom device.

    The custom (privateuse1) backend is looked up as ``torch.<backend_name>``.
    If that module exposes both ``_is_in_bad_fork`` and ``manual_seed_all``,
    it is seeded unless it reports a bad fork; otherwise a ``UserWarning`` is
    emitted explaining which APIs are missing.

    Args:
        seed (int): The desired seed.

    See [Note: support the custom device with privateuse1]
    """
    seed = int(seed)
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    if hasattr(torch, custom_backend_name):
        custom_device_mod = getattr(torch, custom_backend_name)
        _bad_fork_name = "_is_in_bad_fork"
        _seed_all_name = "manual_seed_all"
        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
            custom_device_mod, _seed_all_name
        ):
            if not getattr(custom_device_mod, _bad_fork_name)():
                getattr(custom_device_mod, _seed_all_name)(seed)
        else:
            message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
            message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
            # stacklevel=3: warn -> _seed_custom_device -> manual_seed/seed -> user code.
            warnings.warn(message, UserWarning, stacklevel=3)


def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python `int`.

    .. note:: The returned seed is for the default generator on CPU only.
    """
    return default_generator.initial_seed()


# Set once fork_rng has emitted its "many devices" warning, so the warning is
# shown at most once per process.
_fork_rng_warned_already = False


def _fork_rng_warning_message(
    device_type: str, num_devices: int, caller: str, devices_kw: str
) -> str:
    r"""Builds the warning text emitted by :func:`fork_rng` when it would fork
    every visible device because no explicit device list was given.

    Args:
        device_type (str): Normalized device type, e.g. ``"cuda"``.
        num_devices (int): Number of devices reported by the device module.
        caller (str): Name of the user-facing function that invoked fork_rng.
        devices_kw (str): Name of that function's devices keyword argument.
    """
    upper = device_type.upper()
    return (
        f"{upper} reports that you have {num_devices} available devices, and "
        f"you have used {caller} without explicitly specifying which devices are being used. "
        f"For safety, we initialize *every* {upper} device by default, which can "
        f"be quite slow if you have a lot of {upper}s. If you know that you are only"
        f" making use of a few {upper} devices, set the environment variable "
        f"{upper}_VISIBLE_DEVICES or the '{devices_kw}' keyword argument of {caller} "
        "with the set of devices you are actually using. For example, if you are using CPU only, "
        # This segment previously lacked the f-prefix and printed the literal
        # text "device.upper()_VISIBLE_DEVICES".
        f"set {upper}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
        f"set {upper}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
        f"and suppress this warning, set the '{devices_kw}' keyword argument to "
        f"`range(torch.{device_type}.device_count())`."
    )


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    _caller="fork_rng",
    _devices_kw="devices",
    device_type="cuda",
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `cuda`. As for custom device,
            see details in [Note: support the custom device with privateuse1]

    Raises:
        RuntimeError: if ``torch`` has no module for ``device_type``. This check
            runs even when ``enabled`` is ``False``.
    """
    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller
    global _fork_rng_warned_already

    # Normalize e.g. "cuda:0" to "cuda" before looking up torch.<device_type>.
    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(
            f"torch has no module of `{device_type}`, you should register "
            + "a module by `torch._register_device_module`."
        )

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                _fork_rng_warning_message(
                    device_type, num_devices, _caller, _devices_kw
                )
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    # Snapshot CPU state and per-device state before running the body.
    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        # Restore even if the body raised.
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)