.. _multiprocessing-best-practices:

Multiprocessing best practices
==============================

:mod:`torch.multiprocessing` is a drop-in replacement for Python's
:mod:`python:multiprocessing` module. It supports the exact same operations,
but extends it so that all tensors sent through a
:class:`python:multiprocessing.Queue` have their data moved into shared
memory, and only a handle is sent to the other process.

.. note::

    When a :class:`~torch.Tensor` is sent to another process, the
    :class:`~torch.Tensor` data is shared. If :attr:`torch.Tensor.grad` is
    not ``None``, it is also shared. After a :class:`~torch.Tensor` without
    a :attr:`torch.Tensor.grad` field is sent to the other process, it
    creates a standard process-specific ``.grad`` :class:`~torch.Tensor` that
    is not automatically shared across all processes, unlike the way the
    :class:`~torch.Tensor`'s data has been shared.

This makes it possible to implement various training methods, like Hogwild,
A3C, or any others that require asynchronous operation.

.. _multiprocessing-cuda-note:

CUDA in multiprocessing
-----------------------

The CUDA runtime does not support the ``fork`` start method; either the
``spawn`` or ``forkserver`` start method is required to use CUDA in
subprocesses.

.. note::

    The start method can be set either by creating a context with
    ``multiprocessing.get_context(...)`` or by calling
    ``multiprocessing.set_start_method(...)`` directly.

Unlike CPU tensors, the sending process is required to keep the original
tensor alive for as long as the receiving process retains a copy of it. This
is handled under the hood, but it requires users to follow best practices for
the program to run correctly. For example, the sending process must stay alive
as long as the consumer process has references to the tensor, and the
refcounting cannot save you if the consumer process exits abnormally via a
fatal signal. See :ref:`this section <multiprocessing-cuda-sharing-details>`.
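
For example, a minimal sketch of this pattern (assuming a CUDA-capable GPU is
available; the ``consumer`` function and the tensor below are only
illustrative) keeps the sending process alive until the receiving process is
done with the tensor::

    import torch
    import torch.multiprocessing as mp

    def consumer(queue):
        # The received tensor refers to memory owned by the sending process.
        tensor = queue.get()
        print(tensor.sum().item())

    if __name__ == '__main__':
        # CUDA requires the ``spawn`` (or ``forkserver``) start method.
        ctx = mp.get_context('spawn')
        queue = ctx.SimpleQueue()
        tensor = torch.ones(4, device='cuda')

        process = ctx.Process(target=consumer, args=(queue,))
        process.start()
        queue.put(tensor)

        # Keep the sending process (and the original tensor) alive until
        # the consumer has finished using it.
        process.join()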

See also: :ref:`cuda-nn-ddp-instead`


Best practices and tips
-----------------------

Avoiding and fighting deadlocks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

There are a lot of things that can go wrong when a new process is spawned,
with the most common cause of deadlocks being background threads. If there's
any thread that holds a lock or imports a module, and ``fork`` is called, it's
very likely that the subprocess will be in a corrupted state and will deadlock
or fail in a different way. Note that even if your code doesn't spawn any
threads, Python's built-in libraries do - no need to look further than
:mod:`python:multiprocessing`. :class:`python:multiprocessing.Queue` is
actually a very complex class that spawns multiple threads used to serialize,
send, and receive objects, and they can cause the aforementioned problems too.
If you find yourself in such a situation, try using a
:class:`~python:multiprocessing.queues.SimpleQueue`, which doesn't use any
additional threads.

We're trying our best to make it easy for you and to ensure these deadlocks
don't happen, but some things are out of our control. If you have an issue you
can't cope with for a while, try reaching out on the forums, and we'll see if
it's an issue we can fix.

Reuse buffers passed through a Queue
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Remember that each time you put a :class:`~torch.Tensor` into a
:class:`python:multiprocessing.Queue`, it has to be moved into shared memory.
If it's already shared, this is a no-op; otherwise it will incur an additional
memory copy that can slow down the whole process. Even if you have a pool of
processes sending data to a single one, make it send the buffers back - this
is nearly free and will let you avoid a copy when sending the next batch.
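
A rough sketch of this recycling pattern (the queue names, buffer shapes, and
random data are only illustrative): the producer draws pre-allocated shared
buffers from a free-buffer queue, fills them in place, and gets them back from
the consumer once each batch has been used::

    import torch
    import torch.multiprocessing as mp

    def producer(free_queue, data_queue, num_batches):
        for _ in range(num_batches):
            # Reuse a buffer that already lives in shared memory instead
            # of allocating (and re-sharing) a new tensor for every batch.
            buffer = free_queue.get()
            buffer.normal_()          # fill with new data in-place
            data_queue.put(buffer)

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        free_queue, data_queue = ctx.Queue(), ctx.Queue()
        num_batches = 10

        # Pre-allocate a small pool of buffers in shared memory.
        for _ in range(2):
            free_queue.put(torch.empty(64, 128).share_memory_())

        process = ctx.Process(target=producer,
                              args=(free_queue, data_queue, num_batches))
        process.start()
        for _ in range(num_batches):
            batch = data_queue.get()
            batch.sum()               # consume the batch ...
            free_queue.put(batch)     # ... then send the buffer back
        process.join()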

Asynchronous multiprocess training (e.g. Hogwild)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using :mod:`torch.multiprocessing`, it is possible to train a model
asynchronously, with parameters either shared all the time or periodically
synchronized. In the first case, we recommend sending over the whole model
object, while in the latter, we advise sending only the
:meth:`~torch.nn.Module.state_dict`.

We recommend using :class:`python:multiprocessing.Queue` for passing all kinds
of PyTorch objects between processes. It is possible to, e.g., inherit tensors
and storages that are already in shared memory when using the ``fork`` start
method; however, this is very bug prone and should be used with care, and only
by advanced users. Queues, even though they're sometimes a less elegant
solution, will work properly in all cases.

.. warning::

    You should be careful about global statements that are not guarded by an
    ``if __name__ == '__main__'`` check. If a start method other than ``fork``
    is used, they will be executed in all subprocesses.

Hogwild
~~~~~~~

A concrete Hogwild implementation can be found in the `examples repository`__,
but to showcase the overall structure of the code, there's also a minimal
example below::

    import torch.multiprocessing as mp
    from model import MyModel

    def train(model):
        # Construct data_loader, optimizer, etc.
        for data, labels in data_loader:
            optimizer.zero_grad()
            loss_fn(model(data), labels).backward()
            optimizer.step()  # This will update the shared parameters

    if __name__ == '__main__':
        num_processes = 4
        model = MyModel()
        # NOTE: this is required for the ``fork`` method to work
        model.share_memory()
        processes = []
        for rank in range(num_processes):
            p = mp.Process(target=train, args=(model,))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

.. __: https://github.com/pytorch/examples/tree/master/mnist_hogwild


CPU in multiprocessing
----------------------

Inappropriate multiprocessing can lead to CPU oversubscription, causing
different processes to compete for CPU resources and resulting in low
efficiency.

This section explains what CPU oversubscription is and how to avoid it.

CPU oversubscription
^^^^^^^^^^^^^^^^^^^^

CPU oversubscription is a technical term that refers to a situation where the
total number of vCPUs allocated to a system exceeds the total number of vCPUs
available on the hardware.

This leads to severe contention for CPU resources. In such cases, there is
frequent switching between processes, which increases process-switching
overhead and decreases overall system efficiency.

CPU oversubscription can be observed with the Hogwild implementation found in
the `example repository <https://github.com/pytorch/examples/tree/main/mnist_hogwild>`__
when running the training example with the following command on CPU using 4
processes:

.. code-block:: bash

   python main.py --num-processes 4

Assuming there are N vCPUs available on the machine, executing the above
command will generate 4 subprocesses. Each subprocess will allocate N vCPUs
for itself, resulting in a requirement of 4*N vCPUs. However, the machine only
has N vCPUs available. Consequently, the different processes will compete for
resources, leading to frequent process switching.

The following observations indicate the presence of CPU oversubscription:

#. High CPU Utilization: By using the ``htop`` command, you can observe that
   the CPU utilization is consistently high, often reaching or exceeding its
   maximum capacity. This indicates that the demand for CPU resources exceeds
   the available physical cores, causing contention and competition among
   processes for CPU time.

#. Frequent Context Switching with Low System Efficiency: In an oversubscribed
   CPU scenario, processes compete for CPU time, and the operating system
   needs to rapidly switch between different processes to allocate resources
   fairly. This frequent context switching adds overhead and reduces the
   overall system efficiency.

Avoid CPU oversubscription
^^^^^^^^^^^^^^^^^^^^^^^^^^

A good way to avoid CPU oversubscription is proper resource allocation. Ensure
that the number of processes or threads running concurrently does not exceed
the available CPU resources.

In this case, a solution is to specify the appropriate number of threads in
the subprocesses. This can be achieved by setting the number of threads for
each process with the ``torch.set_num_threads(int)`` function in the
subprocess.

Assuming there are N vCPUs on the machine and M processes will be generated,
the maximum ``num_threads`` value used by each process would be
``floor(N/M)``. To avoid CPU oversubscription in the mnist_hogwild example,
the following changes are needed to the file ``train.py`` in the `example
repository <https://github.com/pytorch/examples/tree/main/mnist_hogwild>`__.

.. code:: python

    def train(rank, args, model, device, dataset, dataloader_kwargs):
        torch.manual_seed(args.seed + rank)

        # Define the number of threads used in the current subprocess,
        # where N is the number of vCPUs and M is the number of processes.
        torch.set_num_threads(floor(N/M))

        train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
        for epoch in range(1, args.epochs + 1):
            train_epoch(epoch, args, model, device, train_loader, optimizer)

Set ``num_threads`` for each process using
``torch.set_num_threads(floor(N/M))``, where you replace N with the number of
vCPUs available and M with the chosen number of processes. The appropriate
``num_threads`` value will vary depending on the specific task at hand;
however, as a general guideline, its maximum value should be ``floor(N/M)`` to
avoid CPU oversubscription. In the
`mnist_hogwild <https://github.com/pytorch/examples/tree/main/mnist_hogwild>`__
training example, avoiding CPU oversubscription yields a 30x performance boost.
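
If you prefer to derive the limit at runtime instead of hard-coding it, the
following standalone sketch (not part of the example repository) computes
``floor(N/M)`` with ``os.cpu_count()``, which reports the vCPUs visible to the
operating system and may differ from what is actually reserved for your job::

    import os

    import torch
    import torch.multiprocessing as mp

    def train(rank, num_threads):
        # Cap intra-op parallelism so that all workers together use at
        # most the number of available vCPUs.
        torch.set_num_threads(num_threads)
        # ... construct the data loader, model and optimizer here ...

    if __name__ == '__main__':
        num_processes = 4                                  # M
        num_vcpus = os.cpu_count() or 1                    # N
        num_threads = max(1, num_vcpus // num_processes)   # floor(N / M)

        processes = []
        for rank in range(num_processes):
            p = mp.Process(target=train, args=(rank, num_threads))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()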