.. _reproducibility:

Reproducibility
===============

Completely reproducible results are not guaranteed across PyTorch releases,
individual commits, or different platforms. Furthermore, results may not be
reproducible between CPU and GPU executions, even when using identical seeds.

However, there are some steps you can take to limit the number of sources of
nondeterministic behavior for a specific platform, device, and PyTorch release.
First, you can control sources of randomness that can cause multiple executions
of your application to behave differently. Second, you can configure PyTorch
to avoid using nondeterministic algorithms for some operations, so that multiple
calls to those operations, given the same inputs, will produce the same result.

.. warning::

    Deterministic operations are often slower than nondeterministic operations, so
    single-run performance may decrease for your model. However, determinism may
    save time in development by facilitating experimentation, debugging, and
    regression testing.

Controlling sources of randomness
.................................

PyTorch random number generator
-------------------------------
You can use :meth:`torch.manual_seed()` to seed the RNG for all devices (both
CPU and CUDA)::

    import torch
    torch.manual_seed(0)

Some PyTorch operations may use random numbers internally.
:meth:`torch.svd_lowrank()` does this, for instance. Consequently, calling it
multiple times back-to-back with the same input arguments may give different
results. However, as long as :meth:`torch.manual_seed()` is set to a constant
at the beginning of an application and all other sources of nondeterminism have
been eliminated, the same series of random numbers will be generated each time
the application is run in the same environment.

It is also possible to obtain identical results from an operation that uses
random numbers by setting :meth:`torch.manual_seed()` to the same value between
subsequent calls.

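For instance, here is a minimal sketch of that pattern using
:meth:`torch.svd_lowrank()` (the matrix ``A`` is just an illustrative input)::

    import torch

    A = torch.randn(5, 5)

    torch.manual_seed(0)
    U1, S1, V1 = torch.svd_lowrank(A, q=3)

    torch.manual_seed(0)
    U2, S2, V2 = torch.svd_lowrank(A, q=3)

    # Re-seeding with the same value before each call should make the
    # two results match
    assert torch.allclose(S1, S2)
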
Python
------

For custom operators, you might need to set the Python seed as well::

    import random
    random.seed(0)

Random number generators in other libraries
-------------------------------------------
If you or any of the libraries you are using rely on NumPy, you can seed the global
NumPy RNG with::

    import numpy as np
    np.random.seed(0)

However, some applications and libraries may use NumPy Random Generator objects,
not the global RNG
(`<https://numpy.org/doc/stable/reference/random/generator.html>`_), and those will
need to be seeded consistently as well.

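As a minimal sketch, a Generator object can be given a fixed seed when it is
created (the variable name ``rng`` is just illustrative)::

    import numpy as np

    # Seed a Generator object directly, rather than the global NumPy RNG
    rng = np.random.default_rng(0)
    sample = rng.random(3)
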
If you are using any other libraries that use random number generators, refer to
the documentation for those libraries to see how to set consistent seeds for them.

CUDA convolution benchmarking
-----------------------------
The cuDNN library, used by CUDA convolution operations, can be a source of nondeterminism
across multiple executions of an application. When a cuDNN convolution is called with a
new set of size parameters, an optional feature can run multiple convolution algorithms,
benchmarking them to find the fastest one. Then, the fastest algorithm will be used
consistently during the rest of the process for the corresponding set of size parameters.
Due to benchmarking noise and different hardware, the benchmark may select different
algorithms on subsequent runs, even on the same machine.

Disabling the benchmarking feature with :code:`torch.backends.cudnn.benchmark = False`
causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced
performance.

However, if you do not need reproducibility across multiple executions of your application,
then performance might improve if the benchmarking feature is enabled with
:code:`torch.backends.cudnn.benchmark = True`.

Note that this setting is different from the :code:`torch.backends.cudnn.deterministic`
setting discussed below.

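As a minimal sketch, the feature can be turned off before any convolutions run::

    import torch

    # Make cuDNN choose convolution algorithms deterministically instead of
    # benchmarking candidate algorithms (possibly at some cost in speed)
    torch.backends.cudnn.benchmark = False
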
Avoiding nondeterministic algorithms
....................................
:meth:`torch.use_deterministic_algorithms` lets you configure PyTorch to use
deterministic algorithms instead of nondeterministic ones where available, and
to throw an error if an operation is known to be nondeterministic (and without
a deterministic alternative).

Please check the documentation for :meth:`torch.use_deterministic_algorithms()`
for a full list of affected operations. If an operation does not act correctly
according to the documentation, or if you need a deterministic implementation
of an operation that does not have one, please submit an issue:
`<https://github.com/pytorch/pytorch/issues?q=label:%22module:%20determinism%22>`_

For example, running the nondeterministic CUDA implementation of :meth:`torch.Tensor.index_add_`
will throw an error::

    >>> import torch
    >>> torch.use_deterministic_algorithms(True)
    >>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
    Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
    'torch.use_deterministic_algorithms(True)'. ...

When :meth:`torch.bmm` is called with sparse-dense CUDA tensors it typically uses a
nondeterministic algorithm, but when the deterministic flag is turned on, its alternate
deterministic implementation will be used::

    >>> import torch
    >>> torch.use_deterministic_algorithms(True)
    >>> torch.bmm(torch.randn(2, 2, 2).to_sparse().cuda(), torch.randn(2, 2, 2).cuda())
    tensor([[[ 1.1900, -2.3409],
             [ 0.4796,  0.8003]],
            [[ 0.1509,  1.8027],
             [ 0.0333, -1.1444]]], device='cuda:0')

Furthermore, if you are using CUDA tensors, and your CUDA version is 10.2 or greater, you
should set the environment variable `CUBLAS_WORKSPACE_CONFIG` according to CUDA documentation:
`<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_

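For instance, one way to do this from within a Python process is to set the
variable before any cuBLAS work is launched (``:4096:8`` and ``:16:8`` are the
values given in the documentation linked above)::

    import os

    # Required for reproducible results from cuBLAS routines on CUDA 10.2+;
    # must be set before the first cuBLAS call in the process
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
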
CUDA convolution determinism
----------------------------
While disabling CUDA convolution benchmarking (discussed above) ensures that
CUDA selects the same algorithm each time an application is run, that algorithm
itself may be nondeterministic, unless either
:code:`torch.use_deterministic_algorithms(True)` or
:code:`torch.backends.cudnn.deterministic = True` is set. The latter setting
controls only this behavior, unlike :meth:`torch.use_deterministic_algorithms`,
which will make other PyTorch operations behave deterministically, too.

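A minimal sketch of the two options::

    import torch

    # Option 1: only make cuDNN convolutions use deterministic algorithms
    torch.backends.cudnn.deterministic = True

    # Option 2: request deterministic algorithms for all PyTorch operations
    # that support them (and error out on the ones that do not)
    torch.use_deterministic_algorithms(True)
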
CUDA RNN and LSTM
-----------------
In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior.
See :meth:`torch.nn.RNN` and :meth:`torch.nn.LSTM` for details and workarounds.

Filling uninitialized memory
----------------------------
Operations like :meth:`torch.empty` and :meth:`torch.Tensor.resize_` can return
tensors with uninitialized memory that contain undefined values. Using such a
tensor as an input to another operation is invalid if determinism is required,
because the output will be nondeterministic. But there is nothing to actually
prevent such invalid code from being run. So for safety,
:attr:`torch.utils.deterministic.fill_uninitialized_memory` is set to ``True``
by default, which will fill the uninitialized memory with a known value if
:code:`torch.use_deterministic_algorithms(True)` is set. This prevents
the possibility of this kind of nondeterministic behavior.

However, filling uninitialized memory is detrimental to performance. So if your
program is valid and does not use uninitialized memory as the input to an
operation, then this setting can be turned off for better performance.

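A minimal sketch of turning the setting off, which is only safe if your
program never reads uninitialized values::

    import torch

    torch.use_deterministic_algorithms(True)

    # Skip filling uninitialized memory for better performance; only safe if
    # tensors from torch.empty(), torch.Tensor.resize_(), etc. are always
    # written before they are read
    torch.utils.deterministic.fill_uninitialized_memory = False
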
DataLoader
..........

DataLoader will reseed workers following the :ref:`data-loading-randomness` algorithm.
Use :meth:`worker_init_fn` and `generator` to preserve reproducibility::

    import random

    import numpy
    import torch
    from torch.utils.data import DataLoader

    def seed_worker(worker_id):
        # Use the seed PyTorch assigned to this worker (reduced to 32 bits)
        # to seed NumPy and Python's RNGs as well
        worker_seed = torch.initial_seed() % 2**32
        numpy.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)

    DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
184