# mypy: allow-untyped-defs
from typing import Iterable, List, Union

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]


def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
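
    Example (a minimal sketch, assuming at least one visible CUDA device;
    shows a save/restore round trip)::

        >>> state = torch.cuda.get_rng_state()  # snapshot the current state
        >>> a = torch.rand(2, device="cuda")
        >>> torch.cuda.set_rng_state(state)     # rewind the generator
        >>> b = torch.rand(2, device="cuda")    # reproduces ``a``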
32    """
33    _lazy_init()
34    if isinstance(device, str):
35        device = torch.device(device)
36    elif isinstance(device, int):
37        device = torch.device("cuda", device)
38    idx = device.index
39    if idx is None:
40        idx = current_device()
41    default_generator = torch.cuda.default_generators[idx]
42    return default_generator.get_state()
43
44
45def get_rng_state_all() -> List[Tensor]:
46    r"""Return a list of ByteTensor representing the random number states of all devices."""
    results = []
    for i in range(device_count()):
        results.append(get_rng_state(i))
    return results


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device or int, optional): The device to set the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
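
    Example (a sketch; restores a state captured with :func:`get_rng_state`)::

        >>> state = torch.cuda.get_rng_state()
        >>> torch.cuda.set_rng_state(state)            # restore on the current device
        >>> torch.cuda.set_rng_state(state, device=0)  # or target device 0 explicitly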
62    """
63    with torch._C._DisableFuncTorch():
64        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
65    if isinstance(device, str):
66        device = torch.device(device)
67    elif isinstance(device, int):
68        device = torch.device("cuda", device)
69
70    def cb():
71        idx = device.index
72        if idx is None:
73            idx = current_device()
74        default_generator = torch.cuda.default_generators[idx]
75        default_generator.set_state(new_state_copy)
76
77    _lazy_call(cb)
78
79
80def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
81    r"""Set the random number generator state of all devices.
82
83    Args:
84        new_states (Iterable of torch.ByteTensor): The desired state for each device.
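
    Example (a sketch; the iterable is consumed in device order, so the
    snapshot from :func:`get_rng_state_all` can be passed back directly)::

        >>> states = torch.cuda.get_rng_state_all()
        >>> torch.cuda.set_rng_state_all(states)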
85    """
86    for i, state in enumerate(new_states):
87        set_rng_state(state, i)
88
89
90def manual_seed(seed: int) -> None:
91    r"""Set the seed for generating random numbers for the current GPU.
92
93    It's safe to call this function if CUDA is not available; in that
94    case, it is silently ignored.
95
96    Args:
97        seed (int): The desired seed.
98
99    .. warning::
100        If you are working with a multi-GPU model, this function is insufficient
101        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
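
    Example (a sketch; seeds only the current device)::

        >>> torch.cuda.manual_seed(42)
        >>> torch.cuda.initial_seed()  # doctest: +SKIP
        42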
102    """
103    seed = int(seed)
104
105    def cb():
106        idx = current_device()
107        default_generator = torch.cuda.default_generators[idx]
108        default_generator.manual_seed(seed)
109
110    _lazy_call(cb, seed=True)
111
112
113def manual_seed_all(seed: int) -> None:
114    r"""Set the seed for generating random numbers on all GPUs.
115
116    It's safe to call this function if CUDA is not available; in that
117    case, it is silently ignored.
118
119    Args:
120        seed (int): The desired seed.
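
    Example (a sketch; the usual call for reproducible multi-GPU runs)::

        >>> torch.cuda.manual_seed_all(42)  # every device gets seed 42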
121    """
122    seed = int(seed)
123
124    def cb():
125        for i in range(device_count()):
126            default_generator = torch.cuda.default_generators[i]
127            default_generator.manual_seed(seed)
128
129    _lazy_call(cb, seed_all=True)
130
131
132def seed() -> None:
133    r"""Set the seed for generating random numbers to a random number for the current GPU.
134
135    It's safe to call this function if CUDA is not available; in that
136    case, it is silently ignored.
137
138    .. warning::
139        If you are working with a multi-GPU model, this function will only initialize
140        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
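
    Example (a sketch; the freshly drawn seed can be read back with
    :func:`initial_seed`)::

        >>> torch.cuda.seed()                          # re-seed nondeterministically
        >>> torch.cuda.initial_seed()  # doctest: +SKIP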
141    """
142
143    def cb():
144        idx = current_device()
145        default_generator = torch.cuda.default_generators[idx]
146        default_generator.seed()
147
148    _lazy_call(cb)
149
150
151def seed_all() -> None:
152    r"""Set the seed for generating random numbers to a random number on all GPUs.
153
154    It's safe to call this function if CUDA is not available; in that
155    case, it is silently ignored.
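
    Example (a sketch; afterwards every device shares the same randomly
    drawn seed)::

        >>> torch.cuda.seed_all()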
156    """
157
158    def cb():
159        random_seed = 0
160        seeded = False
161        for i in range(device_count()):
162            default_generator = torch.cuda.default_generators[i]
163            if not seeded:
164                default_generator.seed()
165                random_seed = default_generator.initial_seed()
166                seeded = True
167            else:
168                default_generator.manual_seed(random_seed)
169
170    _lazy_call(cb)
171
172
173def initial_seed() -> int:
174    r"""Return the current random seed of the current GPU.
175
176    .. warning::
177        This function eagerly initializes CUDA.
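
    Example (a sketch; assumes the device was seeded with :func:`manual_seed`)::

        >>> torch.cuda.manual_seed(7)
        >>> torch.cuda.initial_seed()  # doctest: +SKIP
        7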
178    """
179    _lazy_init()
180    idx = current_device()
181    default_generator = torch.cuda.default_generators[idx]
182    return default_generator.initial_seed()