.. _amp-examples:

Automatic Mixed Precision examples
=======================================

.. currentmodule:: torch.amp

Ordinarily, "automatic mixed precision training" means training with
:class:`torch.autocast` and :class:`torch.amp.GradScaler` together.

Instances of :class:`torch.autocast` enable autocasting for chosen regions.
Autocasting automatically chooses the precision for operations to improve performance
while maintaining accuracy.

Instances of :class:`torch.amp.GradScaler` help perform the steps of
gradient scaling conveniently.  Gradient scaling improves convergence for networks with ``float16``
gradients (the default on CUDA and XPU) by minimizing gradient underflow, as explained :ref:`here<gradient-scaling>`.
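
As a toy illustration of underflow (not part of the AMP API), the sketch below uses Python's
``struct`` half-precision format (``'e'``) as a stand-in for ``float16`` storage; the helper
``to_float16`` and the chosen values are hypothetical.  A gradient of ``1e-8`` rounds to zero in
``float16``, but survives if it is multiplied by a scale factor before the cast and divided by the
same factor afterwards.  An oversized value overflows to ``inf`` instead, which is why scaled
gradients must be checked for infs/NaNs::

    import math
    import struct


    def to_float16(x):
        """Round a Python float to the nearest float16 value (hypothetical helper)."""
        try:
            return struct.unpack("<e", struct.pack("<e", x))[0]
        except OverflowError:
            # struct rejects values beyond float16's range; model them as +/-inf.
            return math.copysign(math.inf, x)


    grad = 1e-8          # a small "true" gradient value (illustrative)
    scale = 2.0 ** 16    # an illustrative scale factor

    # Without scaling, the gradient underflows to zero in float16.
    unscaled_fp16 = to_float16(grad)

    # With scaling, the float16 value is nonzero, and dividing by the scale
    # in higher precision approximately recovers the original gradient.
    recovered = to_float16(grad * scale) / scale

    # Too large a value overflows instead of underflowing.
    overflowed = to_float16(70000.0)

    print(unscaled_fp16)   # 0.0
    print(recovered, overflowed)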

:class:`torch.autocast` and :class:`torch.amp.GradScaler` are modular.
In the samples below, each is used as its individual documentation suggests.

(Samples here are illustrative.  See the
`Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
for a runnable walkthrough.)

.. contents:: :local:

Typical Mixed Precision Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    # Creates model and optimizer in default precision
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    # Creates a GradScaler once at the beginning of training.
    scaler = GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass with autocasting.
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()
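
Conceptually, ``scaler.step()`` and ``scaler.update()`` implement *dynamic* loss scaling.  The
pure-Python model below is a simplified sketch, not PyTorch's implementation: ``ToyScaler`` is
hypothetical, gradients are plain floats, the optimizer is plain SGD, and the defaults mirror
those of :class:`GradScaler`.  An iteration whose unscaled gradients contain an inf or NaN skips
the parameter update and multiplies the scale by ``backoff_factor``; after ``growth_interval``
consecutive finite iterations the scale is multiplied by ``growth_factor``::

    import math


    class ToyScaler:
        """Simplified dynamic loss scaler (hypothetical, for illustration only)."""

        def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                     backoff_factor=0.5, growth_interval=2000):
            self.scale = init_scale
            self.growth_factor = growth_factor
            self.backoff_factor = backoff_factor
            self.growth_interval = growth_interval
            self.growth_tracker = 0     # consecutive iterations without inf/NaN
            self.found_inf = False

        def scale_loss(self, loss):
            # Backward of scale * loss yields gradients multiplied by the same scale.
            return loss * self.scale

        def step(self, params, lr=0.1):
            """params: list of [value, scaled_grad]; returns False if the update is skipped."""
            grads = [g / self.scale for _, g in params]          # unscale
            self.found_inf = not all(math.isfinite(g) for g in grads)
            if self.found_inf:
                return False                                     # skip the SGD update
            for p, g in zip(params, grads):
                p[0] -= lr * g
            return True

        def update(self):
            if self.found_inf:
                self.scale *= self.backoff_factor
                self.growth_tracker = 0
            else:
                self.growth_tracker += 1
                if self.growth_tracker == self.growth_interval:
                    self.scale *= self.growth_factor
                    self.growth_tracker = 0
            self.found_inf = False


    scaler = ToyScaler(growth_interval=3)
    w = [[1.0, 0.0]]                   # one parameter as [value, scaled grad]

    w[0][1] = math.inf                 # simulate an overflowed scaled gradient
    stepped = scaler.step(w)           # False: update skipped, value unchanged
    scaler.update()                    # scale backs off from 2**16 to 2**15
    scale_after_skip = scaler.scale

    for _ in range(3):                 # three finite iterations: scale grows back
        w[0][1] = scaler.scale_loss(0.5)   # scaled version of a true gradient of 0.5
        scaler.step(w)
        scaler.update()

The toy keeps everything in Python floats so the control flow is easy to follow; it does not model
device tensors or per-optimizer state.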

.. _working-with-unscaled-gradients:

Working with Unscaled Gradients
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

All gradients produced by ``scaler.scale(loss).backward()`` are scaled.  If you wish to modify or inspect
the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``, you should
unscale them first.  For example, gradient clipping manipulates a set of gradients such that their global norm
(see :func:`torch.nn.utils.clip_grad_norm_`) or maximum magnitude (see :func:`torch.nn.utils.clip_grad_value_`)
is :math:`<=` some user-imposed threshold.  If you attempted to clip *without* unscaling, the gradients' norm/maximum
magnitude would also be scaled, so your requested threshold (which was meant to be the threshold for *unscaled*
gradients) would be invalid.
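
A small numeric sketch of the problem, with plain Python floats standing in for gradient tensors.
``clip_by_global_norm`` is a hypothetical, simplified stand-in for :func:`torch.nn.utils.clip_grad_norm_`
(it omits the small epsilon that function adds to the norm), and the scale and threshold are arbitrary.
Unscaling before clipping to ``max_norm=1.0`` yields a global norm of 1, while clipping the scaled
gradients and unscaling afterwards leaves a norm ``scale`` times too small::

    import math


    def global_norm(grads):
        return math.sqrt(sum(g * g for g in grads))


    def clip_by_global_norm(grads, max_norm):
        """Rescale grads so their global L2 norm is at most max_norm (hypothetical, simplified)."""
        coef = min(1.0, max_norm / global_norm(grads))
        return [g * coef for g in grads]


    scale = 1024.0
    max_norm = 1.0
    true_grads = [3.0, 4.0]                          # unscaled global norm is 5
    scaled_grads = [g * scale for g in true_grads]   # what backward() leaves in .grad

    # Correct order: unscale first (what scaler.unscale_ does), then clip.
    clipped_correct = clip_by_global_norm([g / scale for g in scaled_grads], max_norm)

    # Wrong order: clip the scaled grads, then unscale (what scaler.step would do next).
    clipped_wrong = [g / scale for g in clip_by_global_norm(scaled_grads, max_norm)]

    norm_correct = global_norm(clipped_correct)   # about 1.0
    norm_wrong = global_norm(clipped_wrong)       # about 1 / 1024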

``scaler.unscale_(optimizer)`` unscales gradients held by ``optimizer``'s assigned parameters.
If your model or models contain other parameters that were assigned to another optimizer
(say ``optimizer2``), you may call ``scaler.unscale_(optimizer2)`` separately to unscale those
parameters' gradients as well.

Gradient clipping
-----------------

Calling ``scaler.unscale_(optimizer)`` before clipping enables you to clip unscaled gradients as usual::

    scaler = GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)
            scaler.scale(loss).backward()

            # Unscales the gradients of optimizer's assigned params in-place
            scaler.unscale_(optimizer)

            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
            # although it still skips optimizer.step() if the gradients contain infs or NaNs.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()

``scaler`` records that ``scaler.unscale_(optimizer)`` was already called for this optimizer
this iteration, so ``scaler.step(optimizer)`` knows not to redundantly unscale gradients before
(internally) calling ``optimizer.step()``.

.. currentmodule:: torch.amp.GradScaler

.. warning::
    :meth:`unscale_<unscale_>` should only be called once per optimizer per :meth:`step<step>` call,
    and only after all gradients for that optimizer's assigned parameters have been accumulated.
    Calling :meth:`unscale_<unscale_>` twice for a given optimizer between consecutive :meth:`step<step>` calls triggers a RuntimeError.
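
The bookkeeping described above can be modeled with one state flag per optimizer.
``ToyUnscaleTracker`` below is a hypothetical, simplified sketch rather than ``GradScaler``'s code;
optimizers are identified by plain strings and the error messages are invented.  It resets the flags
in ``update()``, lets ``step()`` skip unscaling when ``unscale_()`` already ran, and raises
``RuntimeError`` on a second ``unscale_()``::

    class ToyUnscaleTracker:
        """Hypothetical model of per-optimizer unscale_ bookkeeping (not GradScaler's code)."""

        READY, UNSCALED, STEPPED = "ready", "unscaled", "stepped"

        def __init__(self):
            self.stage = {}          # optimizer name -> stage reached since the last update()
            self.unscale_log = []    # records each time gradients are actually unscaled

        def _unscale(self, opt):
            self.unscale_log.append(opt)

        def unscale_(self, opt):
            stage = self.stage.get(opt, self.READY)
            if stage == self.UNSCALED:
                raise RuntimeError("unscale_() was already called for this optimizer")
            if stage == self.STEPPED:
                raise RuntimeError("unscale_() was called after step()")
            self._unscale(opt)
            self.stage[opt] = self.UNSCALED

        def step(self, opt):
            if self.stage.get(opt, self.READY) == self.READY:
                self._unscale(opt)   # only unscale here if unscale_() did not already
            self.stage[opt] = self.STEPPED

        def update(self):
            self.stage.clear()       # every optimizer starts the next iteration READY


    tracker = ToyUnscaleTracker()
    tracker.unscale_("optimizer")        # explicit unscale, e.g. before clipping
    try:
        tracker.unscale_("optimizer")    # a second call before step() is rejected
        second_call_raised = False
    except RuntimeError:
        second_call_raised = True
    tracker.step("optimizer")            # does not unscale again
    tracker.update()

    tracker.step("optimizer")            # next iteration: no unscale_(), so step() unscales
    tracker.update()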


Working with Scaled Gradients
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Gradient accumulation
---------------------

Gradient accumulation adds gradients over an effective batch of size ``batch_per_iter * iters_to_accumulate``
(``* num_procs`` if distributed).  The scale should be calibrated for the effective batch, which means inf/NaN checking,
step skipping if inf/NaN grads are found, and scale updates should occur at effective-batch granularity.
Also, grads should remain scaled, and the scale factor should remain constant, while grads for a given effective
batch are accumulated.  If grads are unscaled (or the scale factor changes) before accumulation is complete,
the next backward pass will add scaled grads to unscaled grads (or grads scaled by a different factor),
after which it's impossible to recover the accumulated unscaled grads that :meth:`step<step>` must apply.
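
A plain-float sketch of why the scale must stay fixed while grads accumulate (all numbers are
arbitrary and illustrative).  With one scale for the whole effective batch, the accumulated value is
``scale * (g1 + g2)`` and dividing by that scale recovers the true sum.  If the scale changes midway,
the accumulated value mixes two scales, and dividing by the current scale, which is all the step can
do, gives the wrong result::

    micro_grads = [0.25, 0.75]     # unscaled gradients from two micro-batches (arbitrary)
    true_sum = sum(micro_grads)    # what the optimizer step should apply: 1.0

    # Constant scale during accumulation: unscale once, at step time.
    scale = 1024.0
    acc = 0.0
    for g in micro_grads:
        acc += g * scale           # each backward() adds a scaled gradient
    recovered = acc / scale        # 1.0

    # Scale changed after the first micro-batch (e.g. update() called too early).
    scales = [1024.0, 512.0]
    acc_mixed = 0.0
    for g, s in zip(micro_grads, scales):
        acc_mixed += g * s
    recovered_mixed = acc_mixed / scales[-1]   # 1.25: only the current scale is known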

Therefore, if you want to :meth:`unscale_<unscale_>` grads (e.g., to allow clipping unscaled grads),
call :meth:`unscale_<unscale_>` just before :meth:`step<step>`, after all (scaled) grads for the upcoming
:meth:`step<step>` have been accumulated.  Also, only call :meth:`update<update>` at the end of iterations
where you called :meth:`step<step>` for a full effective batch::

    scaler = GradScaler()

    for epoch in epochs:
        for i, (input, target) in enumerate(data):
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)
                loss = loss / iters_to_accumulate

            # Accumulates scaled gradients.
            scaler.scale(loss).backward()

            if (i + 1) % iters_to_accumulate == 0:
                # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

.. currentmodule:: torch.amp

Gradient penalty
----------------

A gradient penalty implementation commonly creates gradients using
:func:`torch.autograd.grad`, combines them to create the penalty value,
and adds the penalty value to the loss.

Here's an ordinary example of an L2 penalty without gradient scaling or autocasting::

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)

            # Creates gradients
            grad_params = torch.autograd.grad(outputs=loss,
                                              inputs=model.parameters(),
                                              create_graph=True)

            # Computes the penalty term and adds it to the loss
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

            loss.backward()

            # clip gradients here, if desired

            optimizer.step()

To implement a gradient penalty *with* gradient scaling, the ``outputs`` Tensor(s)
passed to :func:`torch.autograd.grad` should be scaled.  The resulting gradients
will therefore be scaled, and should be unscaled before being combined to create the
penalty value.

Also, the penalty term computation is part of the forward pass, and therefore should be
inside an :class:`autocast` context.
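
To see why the unscaling step matters, consider plain floats standing in for gradient tensors (the
values and scale are arbitrary).  An L2 penalty computed directly from scaled gradients comes out
``scale`` times too large, while multiplying each gradient by ``1 / scale`` first gives the intended
value::

    import math


    def l2_norm(values):
        return math.sqrt(sum(v * v for v in values))


    scale = 2.0 ** 16
    grads = [3.0, 4.0]                              # true (unscaled) gradients, arbitrary
    scaled_grads = [g * scale for g in grads]       # stand-in for autograd.grad on a scaled output

    wrong_penalty = l2_norm(scaled_grads)           # 5 * scale: penalty inflated by the scale
    inv_scale = 1.0 / scale
    right_penalty = l2_norm([g * inv_scale for g in scaled_grads])   # 5.0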

Here's how that looks for the same L2 penalty::

    scaler = GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)

            # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
            scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                     inputs=model.parameters(),
                                                     create_graph=True)

            # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
            # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
            inv_scale = 1./scaler.get_scale()
            grad_params = [p * inv_scale for p in scaled_grad_params]

            # Computes the penalty term and adds it to the loss
            with autocast(device_type='cuda', dtype=torch.float16):
                grad_norm = 0
                for grad in grad_params:
                    grad_norm += grad.pow(2).sum()
                grad_norm = grad_norm.sqrt()
                loss = loss + grad_norm

            # Applies scaling to the backward call as usual.
            # Accumulates leaf gradients that are correctly scaled.
            scaler.scale(loss).backward()

            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

            # step() and update() proceed as usual.
            scaler.step(optimizer)
            scaler.update()


Working with Multiple Models, Losses, and Optimizers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. currentmodule:: torch.amp.GradScaler

If your network has multiple losses, you must call :meth:`scaler.scale<scale>` on each of them individually.
If your network has multiple optimizers, you may call :meth:`scaler.unscale_<unscale_>` on any of them individually,
and you must call :meth:`scaler.step<step>` on each of them individually.

However, :meth:`scaler.update<update>` should only be called once,
after all optimizers used this iteration have been stepped::

    scaler = torch.amp.GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            with autocast(device_type='cuda', dtype=torch.float16):
                output0 = model0(input)
                output1 = model1(input)
                loss0 = loss_fn(2 * output0 + 3 * output1, target)
                loss1 = loss_fn(3 * output0 - 5 * output1, target)

            # (retain_graph here is unrelated to amp; it's present because in this
            # example, both backward() calls share some sections of graph.)
            scaler.scale(loss0).backward(retain_graph=True)
            scaler.scale(loss1).backward()

            # You can choose which optimizers receive explicit unscaling, if you
            # want to inspect or modify the gradients of the params they own.
            scaler.unscale_(optimizer0)

            scaler.step(optimizer0)
            scaler.step(optimizer1)

            scaler.update()

Each optimizer checks its gradients for infs/NaNs and makes an independent decision
whether or not to skip the step.  This may result in one optimizer skipping the step
while the other one does not.  Since step skipping occurs rarely (every several hundred iterations),
this should not impede convergence.  If you observe poor convergence after adding gradient scaling
to a multiple-optimizer model, please report a bug.
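
A minimal pure-Python model of this behavior, with plain floats for gradients.  ``toy_iteration`` is
a hypothetical helper (scale growth is not modeled): each optimizer's step is skipped independently
based on its own gradients, while the single ``update()`` backs off the shared scale if *any*
optimizer saw an inf or NaN, mirroring how ``GradScaler``'s one ``update()`` call combines the
inf/NaN checks recorded for every optimizer stepped that iteration::

    import math


    def toy_iteration(scale, scaled_grads_per_optimizer, backoff_factor=0.5):
        """One iteration with several optimizers (hypothetical; scale growth is not modeled).

        Returns a dict of which optimizers stepped and the scale after the single update().
        """
        stepped = {}
        any_inf = False
        for name, scaled_grads in scaled_grads_per_optimizer.items():
            grads = [g / scale for g in scaled_grads]
            found_inf = not all(math.isfinite(g) for g in grads)
            stepped[name] = not found_inf        # each optimizer decides independently
            any_inf = any_inf or found_inf
        # update() runs once, after every optimizer has been stepped.
        new_scale = scale * backoff_factor if any_inf else scale
        return stepped, new_scale


    stepped, new_scale = toy_iteration(
        2.0 ** 16,
        {"optimizer0": [1.0, 2.0], "optimizer1": [math.inf, 3.0]},
    )
    # stepped == {"optimizer0": True, "optimizer1": False}; new_scale == 2.0 ** 15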

.. currentmodule:: torch.amp

.. _amp-multigpu:

Working with Multiple GPUs
^^^^^^^^^^^^^^^^^^^^^^^^^^

The issues described here only affect :class:`autocast`.  :class:`GradScaler`\ 's usage is unchanged.

.. _amp-dataparallel:

DataParallel in a single process
--------------------------------

Even though :class:`torch.nn.DataParallel` spawns threads to run the forward pass on each device,
the autocast state is propagated to each of them, so the following works::

    model = MyModel()
    dp_model = nn.DataParallel(model)

    # Sets autocast in the main thread
    with autocast(device_type='cuda', dtype=torch.float16):
        # dp_model's internal threads will autocast.
        output = dp_model(input)
        # loss_fn also runs under autocast
        loss = loss_fn(output)

DistributedDataParallel, one GPU per process
--------------------------------------------

:class:`torch.nn.parallel.DistributedDataParallel`'s documentation recommends one GPU per process for best
performance.  In this case, ``DistributedDataParallel`` does not spawn threads internally,
so usages of :class:`autocast` and :class:`GradScaler` are not affected.

DistributedDataParallel, multiple GPUs per process
--------------------------------------------------

Here :class:`torch.nn.parallel.DistributedDataParallel` may spawn a side thread to run the forward pass on each
device, like :class:`torch.nn.DataParallel` (see :ref:`amp-dataparallel`).  To ensure autocast is enabled
in these side threads, apply autocast as part of your model's ``forward`` method.
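
Autocast state is thread-local, so what matters is the thread in which it is enabled.  The stdlib
analogy below is a minimal sketch, not PyTorch code: ``threading.local`` stands in for autocast
state, and ``AUTOCAST``, ``ToyModel``, and ``run_in_side_thread`` are hypothetical names.  When
nothing propagates the caller's state, a flag enabled only in the main thread is unset in a side
thread, whereas enabling it inside ``forward`` works in whichever thread runs the forward pass::

    import threading

    AUTOCAST = threading.local()   # hypothetical stand-in for thread-local autocast state


    def autocast_enabled():
        return getattr(AUTOCAST, "enabled", False)


    class ToyModel:
        def __init__(self, autocast_in_forward):
            self.autocast_in_forward = autocast_in_forward

        def forward(self):
            """Return whether "autocast" is enabled while the forward pass runs."""
            if not self.autocast_in_forward:
                return autocast_enabled()
            previous = autocast_enabled()
            AUTOCAST.enabled = True        # analogous to `with autocast(...)` inside forward
            try:
                return autocast_enabled()
            finally:
                AUTOCAST.enabled = previous


    def run_in_side_thread(model):
        """Run model.forward() in a new thread that does not inherit the caller's state."""
        result = {}
        worker = threading.Thread(target=lambda: result.update(seen=model.forward()))
        worker.start()
        worker.join()
        return result["seen"]


    AUTOCAST.enabled = True            # enabled in the main thread only
    seen_without = run_in_side_thread(ToyModel(autocast_in_forward=False))   # False
    seen_with = run_in_side_thread(ToyModel(autocast_in_forward=True))       # True
    AUTOCAST.enabled = False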

.. _amp-custom-examples:

Autocast and Custom Autograd Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your network uses :ref:`custom autograd functions<extending-autograd>`
(subclasses of :class:`torch.autograd.Function`), changes are required for
autocast compatibility if any function

* takes multiple floating-point Tensor inputs,
* wraps any autocastable op (see the :ref:`Autocast Op Reference<autocast-op-reference>`), or
* requires a particular ``dtype`` (for example, if it wraps
  `CUDA extensions <https://pytorch.org/tutorials/advanced/cpp_extension.html>`_
  that were only compiled for ``dtype``).

In all cases, if you're importing the function and can't alter its definition, a safe fallback
is to disable autocast and force execution in ``float32`` (or ``dtype``) at any points of use where errors occur::

    with autocast(device_type='cuda', dtype=torch.float16):
        ...
        with autocast(device_type='cuda', dtype=torch.float16, enabled=False):
            output = imported_function(input1.float(), input2.float())

If you're the function's author (or can alter its definition), a better solution is to use the
:func:`torch.amp.custom_fwd` and :func:`torch.amp.custom_bwd` decorators as shown in
the relevant case below.

Functions with multiple inputs or autocastable ops
--------------------------------------------------

Apply :func:`custom_fwd<custom_fwd>` and :func:`custom_bwd<custom_bwd>` (passing only the required
``device_type`` argument) to ``forward`` and ``backward`` respectively.  These ensure ``forward`` executes
with the current autocast state and ``backward`` executes with the same autocast state as ``forward``
(which can prevent type mismatch errors)::

    class MyMM(torch.autograd.Function):
        @staticmethod
        @custom_fwd(device_type='cuda')
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)
        @staticmethod
        @custom_bwd(device_type='cuda')
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)

Now ``MyMM`` can be invoked anywhere, without disabling autocast or manually casting inputs::

    mymm = MyMM.apply

    with autocast(device_type='cuda', dtype=torch.float16):
        output = mymm(input1, input2)

Functions that need a particular ``dtype``
------------------------------------------

Consider a custom function that requires ``torch.float32`` inputs.
Apply :func:`custom_fwd(device_type='cuda', cast_inputs=torch.float32)<custom_fwd>` to ``forward``
and :func:`custom_bwd(device_type='cuda')<custom_bwd>` to ``backward``.
If ``forward`` runs in an autocast-enabled region, the decorators cast floating-point Tensor
inputs to ``float32`` for the device given by the ``device_type`` argument
(see `device_type <../amp.html>`_; ``'cuda'`` in this example) and locally disable autocast
during ``forward`` and ``backward``::

    class MyFloat32Func(torch.autograd.Function):
        @staticmethod
        @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
        def forward(ctx, input):
            ctx.save_for_backward(input)
            ...
            return fwd_output
        @staticmethod
        @custom_bwd(device_type='cuda')
        def backward(ctx, grad):
            ...

Now ``MyFloat32Func`` can be invoked anywhere, without manually disabling autocast or casting inputs::

    func = MyFloat32Func.apply

    with autocast(device_type='cuda', dtype=torch.float16):
        # func will run in float32, regardless of the surrounding autocast state
        output = func(input)

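
To make the two decorator patterns above concrete, here is a stdlib-only model of the behavior this
section describes.  Everything in it (``FakeTensor``, ``toy_autocast``, ``toy_custom_fwd``,
``toy_custom_bwd``) is hypothetical and is not how :func:`custom_fwd` or :func:`custom_bwd` are
implemented: dtypes are strings and a thread-local flag stands in for autocast state.  With
``cast_inputs``, forward sees ``float32`` inputs with autocast disabled; without it, forward sees the
current autocast state, and backward replays whatever state forward ran under, even when called
outside the autocast region::

    import contextlib
    import functools
    import threading
    from dataclasses import dataclass

    _STATE = threading.local()     # hypothetical stand-in for autocast state


    def autocast_enabled():
        return getattr(_STATE, "enabled", False)


    @contextlib.contextmanager
    def toy_autocast(enabled=True):
        previous = autocast_enabled()
        _STATE.enabled = enabled
        try:
            yield
        finally:
            _STATE.enabled = previous


    @dataclass
    class FakeTensor:
        dtype: str                 # dtypes are plain strings in this toy

        def is_floating_point(self):
            return self.dtype in ("float16", "bfloat16", "float32", "float64")

        def to(self, dtype):
            return FakeTensor(dtype)


    def toy_custom_fwd(cast_inputs=None):
        def decorator(forward):
            @functools.wraps(forward)
            def wrapper(ctx, *args):
                if cast_inputs is None or not autocast_enabled():
                    ctx["fwd_autocast"] = autocast_enabled()   # replayed in backward
                    return forward(ctx, *args)
                # Cast floating-point inputs and run forward with autocast disabled.
                args = [a.to(cast_inputs) if isinstance(a, FakeTensor) and a.is_floating_point()
                        else a for a in args]
                ctx["fwd_autocast"] = False
                with toy_autocast(enabled=False):
                    return forward(ctx, *args)
            return wrapper
        return decorator


    def toy_custom_bwd(backward):
        @functools.wraps(backward)
        def wrapper(ctx, *grads):
            with toy_autocast(enabled=ctx["fwd_autocast"]):   # same state as forward
                return backward(ctx, *grads)
        return wrapper


    @toy_custom_fwd(cast_inputs="float32")
    def float32_forward(ctx, x):
        return x.dtype, autocast_enabled()


    @toy_custom_fwd()
    def mm_forward(ctx, a, b):
        return autocast_enabled()


    @toy_custom_bwd
    def any_backward(ctx, grad):
        return autocast_enabled()


    f32_ctx, mm_ctx = {}, {}
    with toy_autocast():
        f32_seen = float32_forward(f32_ctx, FakeTensor("float16"))                # ("float32", False)
        mm_seen = mm_forward(mm_ctx, FakeTensor("float16"), FakeTensor("float32"))  # True
    # Backward typically runs outside the autocast region; it replays forward's state.
    f32_bwd_seen = any_backward(f32_ctx, FakeTensor("float32"))   # False
    mm_bwd_seen = any_backward(mm_ctx, FakeTensor("float16"))     # True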