# mypy: allow-untyped-defs
from typing import Iterable, List, Union

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]


def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int or str, optional): The device to return the RNG state of.
            An ``int`` is interpreted as a CUDA device index; a ``str`` is passed to
            :class:`torch.device`.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    Returns:
        Tensor: The state of the default generator of the selected device.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)
    idx = device.index
    if idx is None:
        # A device without an explicit index means "the current CUDA device".
        idx = current_device()
    default_generator = torch.cuda.default_generators[idx]
    return default_generator.get_state()


def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices.

    Returns:
        List[Tensor]: One state per device, ordered by device index
        (``0`` to ``device_count() - 1``).
    """
    return [get_rng_state(i) for i in range(device_count())]


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int or str, optional): The device to set the RNG state.
            An ``int`` is interpreted as a CUDA device index; a ``str`` is passed to
            :class:`torch.device`.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    # Take a contiguous private copy right away: the callback below is handed
    # to ``_lazy_call`` and uses this copy rather than the caller's tensor.
    # NOTE(review): the clone runs with functorch disabled, presumably so the
    # copy is a plain tensor even under a functorch transform -- confirm.
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb():
        # The index is resolved inside the callback, so an index-less device
        # maps to whichever device is current when the callback runs.
        idx = device.index
        if idx is None:
            idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
            The state at position ``i`` is applied to device index ``i``.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)


def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.manual_seed(seed)

    _lazy_call(cb, seed=True)


def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(device_count()):
            default_generator = torch.cuda.default_generators[i]
            default_generator.manual_seed(seed)

    _lazy_call(cb, seed_all=True)


def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
    """

    def cb():
        idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def cb():
        random_seed = 0
        seeded = False
        for i in range(device_count()):
            default_generator = torch.cuda.default_generators[i]
            if not seeded:
                # The first device draws a fresh random seed ...
                default_generator.seed()
                random_seed = default_generator.initial_seed()
                seeded = True
            else:
                # ... and every other device is seeded with that same value.
                default_generator.manual_seed(random_seed)

    _lazy_call(cb)


def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    Returns:
        int: The value of ``initial_seed()`` of the current device's default generator.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    idx = current_device()
    default_generator = torch.cuda.default_generators[idx]
    return default_generator.initial_seed()