.. _multiprocessing-best-practices:

Multiprocessing best practices
==============================

:mod:`torch.multiprocessing` is a drop-in replacement for Python's
:mod:`python:multiprocessing` module. It supports the exact same operations,
but extends it so that all tensors sent through a
:class:`python:multiprocessing.Queue` have their data moved into shared
memory, and only a handle is sent to the other process.
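
For example, a CPU tensor put into a queue by one process is received by the
other process backed by the same shared memory, so in-place writes made by the
receiver are visible to the sender. A minimal sketch (the ``fill_ones`` worker
and the tensor shape are only illustrative)::

    import torch
    import torch.multiprocessing as mp

    def fill_ones(queue):
        t = queue.get()  # backed by the same shared memory as the sender's tensor
        t.fill_(1)

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        queue = ctx.Queue()
        tensor = torch.zeros(4)
        p = ctx.Process(target=fill_ones, args=(queue,))
        p.start()
        queue.put(tensor)  # moves the tensor's storage into shared memory
        p.join()
        print(tensor)      # tensor([1., 1., 1., 1.])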

.. note::

    When a :class:`~torch.Tensor` is sent to another process, the
    :class:`~torch.Tensor` data is shared. If :attr:`torch.Tensor.grad` is
    not ``None``, it is also shared. If a :class:`~torch.Tensor` is sent
    without a :attr:`torch.Tensor.grad` field, the receiving process creates
    a standard process-specific ``.grad`` :class:`~torch.Tensor` that is not
    automatically shared across all processes, unlike the
    :class:`~torch.Tensor`'s data.

This makes it possible to implement various training methods, like Hogwild,
A3C, or any others that require asynchronous operation.

.. _multiprocessing-cuda-note:

CUDA in multiprocessing
-----------------------

The CUDA runtime does not support the ``fork`` start method; either the
``spawn`` or ``forkserver`` start method is required to use CUDA in
subprocesses.

.. note::
  The start method can be set either by creating a context with
  ``multiprocessing.get_context(...)`` or directly by using
  ``multiprocessing.set_start_method(...)``.
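
For instance, a script that uses CUDA in its workers can select the ``spawn``
start method up front; a minimal sketch, assuming a CUDA device is available
(the ``worker`` body is only illustrative)::

    import torch
    import torch.multiprocessing as mp

    def worker(rank):
        # Safe to use CUDA here because the process was spawned, not forked.
        x = torch.ones(2, 2, device='cuda')
        print(rank, x.sum().item())

    if __name__ == '__main__':
        mp.set_start_method('spawn')  # or: ctx = mp.get_context('spawn')
        processes = [mp.Process(target=worker, args=(rank,)) for rank in range(2)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()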

Unlike CPU tensors, CUDA tensors require the sending process to keep the
original tensor alive for as long as the receiving process retains a copy of
it. This is handled under the hood, but it requires users to follow best
practices for the program to run correctly. For example, the sending process
must stay alive as long as the consumer process has references to the tensor,
and the refcounting cannot save you if the consumer process exits abnormally
via a fatal signal. See
:ref:`this section <multiprocessing-cuda-sharing-details>`.
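
One way to satisfy this is to keep the sender blocked until the receiver
signals that it is done with the tensor. A minimal sketch, assuming a CUDA
device is available (the ``consumer`` function and event are illustrative)::

    import torch
    import torch.multiprocessing as mp

    def consumer(queue, done_event):
        tensor = queue.get()   # refers to memory owned by the sending process
        print(tensor.sum().item())
        del tensor             # drop the reference before signalling
        done_event.set()

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        queue, done_event = ctx.Queue(), ctx.Event()
        p = ctx.Process(target=consumer, args=(queue, done_event))
        p.start()
        tensor = torch.ones(1000, device='cuda')
        queue.put(tensor)
        done_event.wait()      # keep this process (and the tensor) alive
        p.join()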

See also: :ref:`cuda-nn-ddp-instead`


Best practices and tips
-----------------------

Avoiding and fighting deadlocks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

There are a lot of things that can go wrong when a new process is spawned, and
the most common cause of deadlocks is background threads. If any thread holds
a lock or has imported a module when ``fork`` is called, it's very likely that
the subprocess will be in a corrupted state and will deadlock or fail in some
other way. Note that even if you don't create threads yourself, Python's
built-in libraries do - no need to look further than
:mod:`python:multiprocessing`. :class:`python:multiprocessing.Queue` is
actually a very complex class that spawns multiple threads to serialize, send
and receive objects, and they can cause the aforementioned problems too. If
you find yourself in such a situation, try using a
:class:`~python:multiprocessing.queues.SimpleQueue`, which doesn't use any
additional threads.
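
Switching typically only changes how the queue is constructed; a minimal
sketch (the ``worker`` function is only illustrative)::

    import torch
    import torch.multiprocessing as mp

    def worker(queue):
        queue.put(torch.zeros(2))   # same shared-memory semantics as Queue

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        queue = ctx.SimpleQueue()   # no feeder thread, unlike ctx.Queue()
        p = ctx.Process(target=worker, args=(queue,))
        p.start()
        print(queue.get())
        p.join()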

We're trying our best to make it easy for you and to ensure these deadlocks
don't happen, but some things are out of our control. If you run into an issue
you can't resolve for a while, try reaching out on the forums, and we'll see
if it's an issue we can fix.

Reuse buffers passed through a Queue
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Remember that each time you put a :class:`~torch.Tensor` into a
:class:`python:multiprocessing.Queue`, it has to be moved into shared memory.
If it's already shared, this is a no-op; otherwise it will incur an additional
memory copy that can slow down the whole process. Even if you have a pool of
processes sending data to a single one, make it send the buffers back - this
is nearly free and will let you avoid a copy when sending the next batch.
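
A minimal sketch of this round-trip pattern, where one pre-allocated buffer
ping-pongs between a producer and a consumer (the buffer shape and iteration
count are illustrative)::

    import torch
    import torch.multiprocessing as mp

    def producer(data_queue, free_queue):
        for _ in range(10):
            buffer = free_queue.get()   # reuse an already-shared buffer
            buffer.normal_()            # write the next batch in place
            data_queue.put(buffer)      # already shared, so this is cheap

    def consumer(data_queue, free_queue):
        for _ in range(10):
            batch = data_queue.get()
            # ... use the batch ...
            free_queue.put(batch)       # send the buffer back for reuse

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        data_queue, free_queue = ctx.Queue(), ctx.Queue()
        free_queue.put(torch.empty(64, 3, 224, 224))  # pre-allocate one buffer
        workers = [ctx.Process(target=producer, args=(data_queue, free_queue)),
                   ctx.Process(target=consumer, args=(data_queue, free_queue))]
        for w in workers:
            w.start()
        for w in workers:
            w.join()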

Asynchronous multiprocess training (e.g. Hogwild)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using :mod:`torch.multiprocessing`, it is possible to train a model
asynchronously, with parameters either shared all the time or periodically
synchronized. In the first case, we recommend sending over the whole model
object, while in the latter, we advise sending only the
:meth:`~torch.nn.Module.state_dict`.
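
For the periodically synchronized case, a worker can pull fresh parameters
from a queue and load them into its local copy of the model. A minimal sketch
that uses a small ``nn.Linear`` as a stand-in for a real model and elides the
actual training steps::

    import torch.nn as nn
    import torch.multiprocessing as mp

    def train(param_queue, num_syncs):
        model = nn.Linear(4, 2)   # each worker builds its own local copy
        for _ in range(num_syncs):
            # Block until the coordinator publishes fresh parameters, then
            # copy them into the local model.
            model.load_state_dict(param_queue.get())
            # ... run some local training steps on `model` here ...

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        param_queue = ctx.Queue()
        p = ctx.Process(target=train, args=(param_queue, 3))
        p.start()
        master = nn.Linear(4, 2)
        for _ in range(3):
            param_queue.put(master.state_dict())  # only the state_dict is sent
        p.join()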

We recommend using :class:`python:multiprocessing.Queue` for passing all kinds
of PyTorch objects between processes. It is possible, for example, to inherit
tensors and storages that are already in shared memory when using the ``fork``
start method; however, this is very bug-prone and should be used with care,
and only by advanced users. Queues, even though they're sometimes a less
elegant solution, will work properly in all cases.

.. warning::

    You should be careful about global statements that are not guarded by an
    ``if __name__ == '__main__'`` check. If a start method other than ``fork``
    is used, they will be executed in all subprocesses.

Hogwild
~~~~~~~

A concrete Hogwild implementation can be found in the `examples repository`__,
but to showcase the overall structure of the code, there's also a minimal
example below::

    import torch.multiprocessing as mp
    from model import MyModel

    def train(model):
        # Construct data_loader, optimizer, etc.
        for data, labels in data_loader:
            optimizer.zero_grad()
            loss_fn(model(data), labels).backward()
            optimizer.step()  # This will update the shared parameters

    if __name__ == '__main__':
        num_processes = 4
        model = MyModel()
        # NOTE: this is required for the ``fork`` method to work
        model.share_memory()
        processes = []
        for rank in range(num_processes):
            p = mp.Process(target=train, args=(model,))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

.. __: https://github.com/pytorch/examples/tree/master/mnist_hogwild


CPU in multiprocessing
----------------------

Inappropriate multiprocessing can lead to CPU oversubscription, causing
different processes to compete for CPU resources and resulting in low
efficiency.

This section explains what CPU oversubscription is and how to avoid it.

CPU oversubscription
^^^^^^^^^^^^^^^^^^^^

CPU oversubscription is a technical term that refers to a situation where the
total number of vCPUs allocated to a system exceeds the total number of vCPUs
available on the hardware.

This leads to severe contention for CPU resources. In such cases, there is
frequent switching between processes, which increases process-switching
overhead and decreases overall system efficiency.

CPU oversubscription can be observed with the Hogwild implementation found in
the `example
repository <https://github.com/pytorch/examples/tree/main/mnist_hogwild>`__.

Run the training example on CPU with 4 processes using the following command:

.. code-block:: bash

   python main.py --num-processes 4

Assuming there are N vCPUs available on the machine, executing the above
command will generate 4 subprocesses. Each subprocess will allocate N vCPUs
for itself, resulting in a requirement of 4*N vCPUs. However, the machine has
only N vCPUs available. Consequently, the different processes will compete for
resources, leading to frequent process switching.

The following observations indicate the presence of CPU oversubscription:

#. High CPU Utilization: By using the ``htop`` command, you can observe
   that the CPU utilization is consistently high, often reaching or
   exceeding its maximum capacity. This indicates that the demand for
   CPU resources exceeds the available physical cores, causing
   contention and competition among processes for CPU time.

#. Frequent Context Switching with Low System Efficiency: In an
   oversubscribed CPU scenario, processes compete for CPU time, and the
   operating system needs to rapidly switch between different processes
   to allocate resources fairly. This frequent context switching adds
   overhead and reduces the overall system efficiency.

Avoid CPU oversubscription
^^^^^^^^^^^^^^^^^^^^^^^^^^

A good way to avoid CPU oversubscription is proper resource allocation.
Ensure that the number of processes or threads running concurrently does
not exceed the available CPU resources.

In this case, a solution is to specify the appropriate number of threads in
the subprocesses. This can be achieved by setting the number of threads for
each process using the ``torch.set_num_threads(int)`` function in each
subprocess.

Assuming there are N vCPUs on the machine and M processes will be generated,
the maximum ``num_threads`` value used by each process would be
``floor(N/M)``. To avoid CPU oversubscription in the mnist_hogwild example,
the following changes are needed in the file ``train.py`` of the `example
repository <https://github.com/pytorch/examples/tree/main/mnist_hogwild>`__.

.. code:: python

   def train(rank, args, model, device, dataset, dataloader_kwargs):
       torch.manual_seed(args.seed + rank)

       # Limit the number of threads used by this subprocess. N is the number
       # of vCPUs on the machine and M is the number of processes (both are
       # placeholders to be filled in for your setup).
       torch.set_num_threads(floor(N/M))

       train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

       optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
       for epoch in range(1, args.epochs + 1):
           train_epoch(epoch, args, model, device, train_loader, optimizer)

Set the number of threads for each process using
``torch.set_num_threads(floor(N/M))``, where you replace N with the number of
vCPUs available and M with the chosen number of processes. The appropriate
``num_threads`` value will vary depending on the specific task at hand.
However, as a general guideline, the maximum value for ``num_threads`` should
be ``floor(N/M)`` to avoid CPU oversubscription. In the
`mnist_hogwild <https://github.com/pytorch/examples/tree/main/mnist_hogwild>`__
training example, you can achieve a 30x performance boost after avoiding CPU
oversubscription.
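
In practice, N and M do not need to be hard-coded; a minimal sketch that
derives the thread count at runtime (``num_processes`` here stands in for the
``--num-processes`` value)::

    import os
    import torch

    num_processes = 4                # M: chosen number of worker processes
    num_vcpus = os.cpu_count() or 1  # N: vCPUs available on the machine

    # Each subprocess should call this so that the total number of threads
    # across all processes does not exceed N.
    torch.set_num_threads(max(1, num_vcpus // num_processes))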
233