# mypy: allow-untyped-defs
from typing import Iterable, List, Union

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]


def _as_device(device: Union[int, str, torch.device]) -> torch.device:
    """Normalize a user-supplied device argument into a ``torch.device``.

    Strings are parsed by ``torch.device``; ints are treated as CUDA device
    ordinals; any other value is assumed to already be a ``torch.device`` and
    is returned unchanged.
    """
    if isinstance(device, str):
        return torch.device(device)
    if isinstance(device, int):
        return torch.device("cuda", device)
    return device


def _generator_index(device: torch.device) -> int:
    """Return the index into ``torch.cuda.default_generators`` for ``device``.

    A device without an explicit index (e.g. ``torch.device('cuda')``) maps to
    the current CUDA device.
    """
    idx = device.index
    return current_device() if idx is None else idx


def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device, str, or int, optional): The device to return the
            RNG state of. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``,
            the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    idx = _generator_index(_as_device(device))
    return torch.cuda.default_generators[idx].get_state()


def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices.

    The list is ordered by device index, one entry per visible CUDA device.
    """
    return [get_rng_state(i) for i in range(device_count())]


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device, str, or int, optional): The device to set the RNG
            state. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the
            current CUDA device).
    """
    # Snapshot the state eagerly: ``cb`` is handed to ``_lazy_call`` and may
    # run later, so it must not observe subsequent in-place changes by the
    # caller to ``new_state``.
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    device = _as_device(device)

    def cb():
        # Resolve the index inside the callback so ``current_device()`` is
        # queried only when the callback actually runs.
        idx = _generator_index(device)
        torch.cuda.default_generators[idx].set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each
            device; the i-th state is applied to device ``i``.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)


def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        idx = current_device()
        torch.cuda.default_generators[idx].manual_seed(seed)

    _lazy_call(cb, seed=True)


def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(device_count()):
            torch.cuda.default_generators[i].manual_seed(seed)

    _lazy_call(cb, seed_all=True)


def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
    """

    def cb():
        idx = current_device()
        torch.cuda.default_generators[idx].seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    A single random seed is drawn on device 0 and then applied to every other
    device, so all GPUs end up seeded identically.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def cb():
        num_devices = device_count()
        if num_devices == 0:
            return
        first = torch.cuda.default_generators[0]
        first.seed()
        # Broadcast the freshly drawn seed of device 0 to the remaining devices.
        random_seed = first.initial_seed()
        for i in range(1, num_devices):
            torch.cuda.default_generators[i].manual_seed(random_seed)

    _lazy_call(cb)


def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    idx = current_device()
    return torch.cuda.default_generators[idx].initial_seed()