.. _amp-examples:

Automatic Mixed Precision examples
=======================================

.. currentmodule:: torch.amp

Ordinarily, "automatic mixed precision training" means training with
:class:`torch.autocast` and :class:`torch.amp.GradScaler` together.

Instances of :class:`torch.autocast` enable autocasting for chosen regions.
Autocasting automatically chooses the precision for operations to improve performance
while maintaining accuracy.

Instances of :class:`torch.amp.GradScaler` help perform the steps of
gradient scaling conveniently. Gradient scaling improves convergence for networks with ``float16``
gradients (the default on CUDA and XPU) by minimizing gradient underflow, as explained :ref:`here<gradient-scaling>`.

:class:`torch.autocast` and :class:`torch.amp.GradScaler` are modular.
In the samples below, each is used as its individual documentation suggests.

(Samples here are illustrative. See the
`Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
for a runnable walkthrough.)

.. contents:: :local:

Typical Mixed Precision Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    # Creates model and optimizer in default precision
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    # Creates a GradScaler once at the beginning of training.
    scaler = GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass with autocasting.
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)

            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()

.. _working-with-unscaled-gradients:

Working with Unscaled Gradients
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

All gradients produced by ``scaler.scale(loss).backward()`` are scaled. If you wish to modify or inspect
the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``, you should
unscale them first. For example, gradient clipping manipulates a set of gradients such that their global norm
(see :func:`torch.nn.utils.clip_grad_norm_`) or maximum magnitude (see :func:`torch.nn.utils.clip_grad_value_`)
is at most some user-imposed threshold. If you attempted to clip *without* unscaling, the gradients' norm or maximum
magnitude would also be multiplied by the scale factor, so your requested threshold (which was meant to be the
threshold for *unscaled* gradients) would be invalid.

``scaler.unscale_(optimizer)`` unscales gradients held by ``optimizer``'s assigned parameters.
If your model or models contain other parameters that were assigned to another optimizer
(say ``optimizer2``), you may call ``scaler.unscale_(optimizer2)`` separately to unscale those
parameters' gradients as well.
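
The threshold mismatch can be sketched in plain Python, without PyTorch. This is a minimal, hypothetical illustration: the gradient values, ``max_norm``, and the helpers ``global_norm`` and ``clip_to_norm`` are made up for the example, and ``clip_to_norm`` only approximates the clip coefficient that :func:`torch.nn.utils.clip_grad_norm_` computes. It assumes a scale of ``65536.0`` (``2**16``), :class:`GradScaler`'s default initial scale. Clipping the scaled gradients and unscaling afterwards leaves a norm roughly ``scale`` times smaller than requested, while unscaling first clips to the intended threshold:

```python
import math


def global_norm(grads):
    # Total L2 norm over every gradient value (the quantity clip_grad_norm_ bounds by default).
    return math.sqrt(sum(g * g for g in grads))


def clip_to_norm(grads, max_norm, eps=1e-6):
    # Hypothetical helper: rescale grads so their global norm is at most max_norm,
    # using a clip coefficient similar to clip_grad_norm_'s max_norm / (norm + eps).
    coef = min(1.0, max_norm / (global_norm(grads) + eps))
    return [g * coef for g in grads]


# Hypothetical true (unscaled) gradients; their global norm is 1.3.
true_grads = [0.3, -0.4, 1.2]
max_norm = 1.0
scale = 65536.0  # assumed scale factor (GradScaler's default initial scale)

# What .grad would hold after scaler.scale(loss).backward().
scaled_grads = [g * scale for g in true_grads]

# Wrong: clip the scaled grads, then divide by the scale when stepping.
wrong = [g / scale for g in clip_to_norm(scaled_grads, max_norm)]

# Right: unscale first (what scaler.unscale_ does in place), then clip.
right = clip_to_norm([g / scale for g in scaled_grads], max_norm)

print(f"norm after clipping scaled grads:   {global_norm(wrong):.2e}")
print(f"norm after clipping unscaled grads: {global_norm(right):.4f}")
```

The same reasoning applies to :func:`torch.nn.utils.clip_grad_value_`, because the maximum magnitude of the scaled gradients is also multiplied by the scale factor.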

Gradient clipping
-----------------

Calling ``scaler.unscale_(optimizer)`` before clipping enables you to clip unscaled gradients as usual::

    scaler = GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)
            scaler.scale(loss).backward()

            # Unscales the gradients of optimizer's assigned params in-place
            scaler.unscale_(optimizer)

            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
            # although it still skips optimizer.step() if the gradients contain infs or NaNs.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()

``scaler`` records that ``scaler.unscale_(optimizer)`` was already called for this optimizer
this iteration, so ``scaler.step(optimizer)`` knows not to redundantly unscale gradients before
(internally) calling ``optimizer.step()``.

.. currentmodule:: torch.amp.GradScaler

.. warning::
    :meth:`unscale_<unscale_>` should only be called once per optimizer per :meth:`step<step>` call,
    and only after all gradients for that optimizer's assigned parameters have been accumulated.
    Calling :meth:`unscale_<unscale_>` twice for a given optimizer between consecutive :meth:`step<step>` calls
    raises a RuntimeError.


Working with Scaled Gradients
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Gradient accumulation
---------------------

Gradient accumulation adds gradients over an effective batch of size ``batch_per_iter * iters_to_accumulate``
(``* num_procs`` if distributed).
The scale should be calibrated for the effective batch: inf/NaN checking,
step skipping when inf/NaN grads are found, and scale updates should all happen at effective-batch granularity.
Also, grads should remain scaled, and the scale factor should remain constant, while grads for a given effective
batch are accumulated. If grads are unscaled (or the scale factor changes) before accumulation is complete,
the next backward pass will add scaled grads to unscaled grads (or to grads scaled by a different factor),
after which it's impossible to recover the accumulated unscaled grads that :meth:`step<step>` must apply.

Therefore, if you want to :meth:`unscale_<unscale_>` grads (e.g., to allow clipping unscaled grads),
call :meth:`unscale_<unscale_>` just before :meth:`step<step>`, after all (scaled) grads for the upcoming
:meth:`step<step>` have been accumulated. Also, only call :meth:`update<update>` at the end of iterations
where you called :meth:`step<step>` for a full effective batch::

    scaler = GradScaler()

    for epoch in epochs:
        for i, (input, target) in enumerate(data):
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)
                loss = loss / iters_to_accumulate

            # Accumulates scaled gradients.
            scaler.scale(loss).backward()

            if (i + 1) % iters_to_accumulate == 0:
                # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

.. currentmodule:: torch.amp

Gradient penalty
----------------

A gradient penalty implementation commonly creates gradients using
:func:`torch.autograd.grad`, combines them to create the penalty value,
and adds the penalty value to the loss.

Here's an ordinary example of an L2 penalty without gradient scaling or autocasting::

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)

            # Creates gradients
            grad_params = torch.autograd.grad(outputs=loss,
                                              inputs=model.parameters(),
                                              create_graph=True)

            # Computes the penalty term and adds it to the loss
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

            loss.backward()

            # clip gradients here, if desired

            optimizer.step()

To implement a gradient penalty *with* gradient scaling, the ``outputs`` Tensor(s)
passed to :func:`torch.autograd.grad` should be scaled. The resulting gradients
will therefore be scaled, and should be unscaled before being combined to create the
penalty value.

Also, the penalty term computation is part of the forward pass, and therefore should be
inside an :class:`autocast` context.

Here's how that looks for the same L2 penalty::

    scaler = GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)

            # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
            scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                     inputs=model.parameters(),
                                                     create_graph=True)

            # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
            # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
            inv_scale = 1./scaler.get_scale()
            grad_params = [p * inv_scale for p in scaled_grad_params]

            # Computes the penalty term and adds it to the loss
            with autocast(device_type='cuda', dtype=torch.float16):
                grad_norm = 0
                for grad in grad_params:
                    grad_norm += grad.pow(2).sum()
                grad_norm = grad_norm.sqrt()
                loss = loss + grad_norm

            # Applies scaling to the backward call as usual.
            # Accumulates leaf gradients that are correctly scaled.
            scaler.scale(loss).backward()

            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

            # step() and update() proceed as usual.
            scaler.step(optimizer)
            scaler.update()


Working with Multiple Models, Losses, and Optimizers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. currentmodule:: torch.amp.GradScaler

If your network has multiple losses, you must call :meth:`scaler.scale<scale>` on each of them individually.
If your network has multiple optimizers, you may call :meth:`scaler.unscale_<unscale_>` on any of them individually,
and you must call :meth:`scaler.step<step>` on each of them individually.

However, :meth:`scaler.update<update>` should only be called once,
after all optimizers used this iteration have been stepped::

    scaler = torch.amp.GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            with autocast(device_type='cuda', dtype=torch.float16):
                output0 = model0(input)
                output1 = model1(input)
                loss0 = loss_fn(2 * output0 + 3 * output1, target)
                loss1 = loss_fn(3 * output0 - 5 * output1, target)

            # (retain_graph here is unrelated to amp; it's present because in this
            # example, both backward() calls share some sections of graph.)
            scaler.scale(loss0).backward(retain_graph=True)
            scaler.scale(loss1).backward()

            # You can choose which optimizers receive explicit unscaling, if you
            # want to inspect or modify the gradients of the params they own.
            scaler.unscale_(optimizer0)

            scaler.step(optimizer0)
            scaler.step(optimizer1)

            scaler.update()

Each optimizer checks its gradients for infs/NaNs and makes an independent decision
whether or not to skip the step. This may result in one optimizer skipping the step
while the other one does not. Since step skipping occurs rarely (every several hundred iterations),
this should not impede convergence. If you observe poor convergence after adding gradient scaling
to a multiple-optimizer model, please report a bug.

.. currentmodule:: torch.amp

.. _amp-multigpu:

Working with Multiple GPUs
^^^^^^^^^^^^^^^^^^^^^^^^^^

The issues described here only affect :class:`autocast`. :class:`GradScaler`\ 's usage is unchanged.

.. _amp-dataparallel:

DataParallel in a single process
--------------------------------

Even though :class:`torch.nn.DataParallel` spawns threads to run the forward pass on each device,
the autocast state is propagated to each thread, so the following works::

    model = MyModel()
    dp_model = nn.DataParallel(model)

    # Sets autocast in the main thread
    with autocast(device_type='cuda', dtype=torch.float16):
        # dp_model's internal threads will autocast.
        output = dp_model(input)
        # loss_fn also runs under autocast
        loss = loss_fn(output)

DistributedDataParallel, one GPU per process
--------------------------------------------

:class:`torch.nn.parallel.DistributedDataParallel`'s documentation recommends one GPU per process for best
performance. In this case, ``DistributedDataParallel`` does not spawn threads internally,
so usages of :class:`autocast` and :class:`GradScaler` are not affected.

DistributedDataParallel, multiple GPUs per process
--------------------------------------------------

Here :class:`torch.nn.parallel.DistributedDataParallel` may spawn a side thread to run the forward pass on each
device, like :class:`torch.nn.DataParallel`. :ref:`The fix is the same<amp-dataparallel>`:
apply autocast as part of your model's ``forward`` method to ensure it's enabled in side threads.

.. _amp-custom-examples:

Autocast and Custom Autograd Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your network uses :ref:`custom autograd functions<extending-autograd>`
(subclasses of :class:`torch.autograd.Function`), changes are required for
autocast compatibility if any function

* takes multiple floating-point Tensor inputs,
* wraps any autocastable op (see the :ref:`Autocast Op Reference<autocast-op-reference>`), or
* requires a particular ``dtype`` (for example, if it wraps
  `CUDA extensions <https://pytorch.org/tutorials/advanced/cpp_extension.html>`_
  that were only compiled for ``dtype``).

In all cases, if you're importing the function and can't alter its definition, a safe fallback
is to disable autocast and force execution in ``float32`` (or ``dtype``) at any points of use where errors occur::

    with autocast(device_type='cuda', dtype=torch.float16):
        ...
        with autocast(device_type='cuda', dtype=torch.float16, enabled=False):
            output = imported_function(input1.float(), input2.float())

If you're the function's author (or can alter its definition), a better solution is to use the
:func:`torch.amp.custom_fwd` and :func:`torch.amp.custom_bwd` decorators as shown in
the relevant case below.

Functions with multiple inputs or autocastable ops
--------------------------------------------------

Apply :func:`custom_fwd(device_type='cuda')<custom_fwd>` and :func:`custom_bwd(device_type='cuda')<custom_bwd>`
(passing only ``device_type``, without ``cast_inputs``) to ``forward`` and ``backward`` respectively.
These ensure ``forward`` executes with the current autocast state and ``backward``
executes with the same autocast state as ``forward`` (which can prevent type mismatch errors)::

    class MyMM(torch.autograd.Function):
        @staticmethod
        @custom_fwd(device_type='cuda')
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @custom_bwd(device_type='cuda')
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)

Now ``MyMM`` can be invoked anywhere, without disabling autocast or manually casting inputs::

    mymm = MyMM.apply

    with autocast(device_type='cuda', dtype=torch.float16):
        output = mymm(input1, input2)

Functions that need a particular ``dtype``
------------------------------------------

Consider a custom function that requires ``torch.float32`` inputs.
Apply :func:`custom_fwd(device_type='cuda', cast_inputs=torch.float32)<custom_fwd>` to ``forward``
and :func:`custom_bwd(device_type='cuda')<custom_bwd>` to ``backward``.
If ``forward`` runs in an autocast-enabled region, the decorators cast floating-point Tensor
inputs to ``float32`` on the device designated by the `device_type <../amp.html>`_ argument
(CUDA in this example), and locally disable autocast during ``forward`` and ``backward``::

    class MyFloat32Func(torch.autograd.Function):
        @staticmethod
        @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
        def forward(ctx, input):
            ctx.save_for_backward(input)
            ...
            return fwd_output

        @staticmethod
        @custom_bwd(device_type='cuda')
        def backward(ctx, grad):
            ...

Now ``MyFloat32Func`` can be invoked anywhere, without manually disabling autocast or casting inputs::

    func = MyFloat32Func.apply

    with autocast(device_type='cuda', dtype=torch.float16):
        # func will run in float32, regardless of the surrounding autocast state
        output = func(input)