# mypy: allow-untyped-defs
import torch


def convert_conv2d_weight_memory_format(module, memory_format):
    r"""Convert the ``memory_format`` of ``nn.Conv2d.weight`` to the given ``memory_format``.

    The conversion recursively applies to nested ``nn.Module``, including ``module`` itself.
    Note that it only changes the memory_format, but not the semantics of each dimension.
    This function is used to facilitate the adoption of NHWC kernels, which provide a
    considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0.

    .. note::
        Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive
        than the utility function ``convert_conv2d_weight_memory_format``. Any
        layer with a 4d weight will be affected by ``model.to``, including layers
        that do not necessarily benefit from conversion to the specified
        ``memory_format``. One case we are confident about is the NHWC
        (channels_last) conversion for convolution in cuDNN, as it is beneficial
        to run convolution in NHWC even when a permutation has to be applied to
        the input tensors.

        Hence our strategy here is to convert only the weight of convolution to
        channels_last. This ensures that:
        1. Fast convolution kernels will be used, the benefit of which could
        outweigh the overhead of permutation (if the input is not in the same format).
        2. No unnecessary permutations are applied on layers that do not benefit
        from memory_format conversion.

        The optimal case is that layers between convolution layers are channels-last
        compatible. The input tensor is permuted to channels last when it
        encounters the first convolution layer and stays in that memory format,
        so the following convolutions do not need to permute their input tensors.

        In cases where a channels-last-incompatible layer sits between convolution
        layers, we need to permute the input tensor back to contiguous format
        for that layer. The input tensor will go through the remaining layers in
        contiguous format and be permuted to channels last when it encounters
        another convolution layer. There's no point in propagating that
        permutation to an earlier layer, as most layers are quite agnostic to
        ``memory_format``.

        This might change once PyTorch supports fusion of permutations, as
        there may be a better spot to fuse the permutation than immediately
        before a convolution.

    Args:
        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
                            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last`` or ``torch.contiguous_format``

    Returns:
        The original module with updated ``nn.Conv2d`` weights

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
        >>> model = nn.Sequential(
        >>>     nn.Conv2d(8, 4, 3)).cuda().half()
        >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> out = model(input)
    """
    # TODO: expand this to `_ConvNd` when channels_last support is extended
    # beyond only 4d tensors.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        # Clone the weight in the requested memory format, then resize the
        # parameter's data in place so it adopts the new strides.
        weight_data = (
            module.weight.detach().clone().contiguous(memory_format=memory_format)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=memory_format
        )
    for child in module.children():
        convert_conv2d_weight_memory_format(child, memory_format)
    return module
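

# Illustrative sketch only: ``_demo_conv2d_weight_layout`` is a hypothetical
# helper added for documentation purposes and is not part of the torch.nn.utils
# API. It shows the effect of the conversion above: the weight keeps its logical
# shape while its strides switch from contiguous (NCHW) order to channels_last
# order.
def _demo_conv2d_weight_layout():
    conv = torch.nn.Conv2d(8, 4, kernel_size=3)  # weight shape: (4, 8, 3, 3)
    before = conv.weight.stride()  # contiguous strides: (72, 9, 3, 1)
    convert_conv2d_weight_memory_format(conv, torch.channels_last)
    after = conv.weight.stride()  # channels_last strides: (72, 1, 24, 8)
    # The logical shape is untouched; only the layout recorded in the strides changes.
    assert conv.weight.size() == torch.Size([4, 8, 3, 3])
    assert conv.weight.is_contiguous(memory_format=torch.channels_last)
    return before, after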


def convert_conv3d_weight_memory_format(module, memory_format):
    r"""Convert the ``memory_format`` of ``nn.Conv3d.weight`` to the given ``memory_format``.

    The conversion recursively applies to nested ``nn.Module``, including ``module`` itself.
    Note that it only changes the memory_format, but not the semantics of each dimension.
    This function is used to facilitate the adoption of NDHWC kernels, which provide a
    considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0.

    .. note::
        Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
        than the utility function ``convert_conv3d_weight_memory_format``. Any
        layer with a 5d weight will be affected by ``model.to``, including layers
        that do not necessarily benefit from conversion to the specified
        ``memory_format``. One case we are confident about is the NDHWC
        (channels_last_3d) conversion for convolution in cuDNN, as it is beneficial
        to run convolution in NDHWC even when a permutation has to be applied to
        the input tensors.

        Hence our strategy here is to convert only the weight of convolution to
        channels_last_3d. This ensures that:
        1. Fast convolution kernels will be used, the benefit of which could
        outweigh the overhead of permutation (if the input is not in the same format).
        2. No unnecessary permutations are applied on layers that do not benefit
        from memory_format conversion.

        The optimal case is that layers between convolution layers are channels-last
        compatible. The input tensor is permuted to channels last when it
        encounters the first convolution layer and stays in that memory format,
        so the following convolutions do not need to permute their input tensors.

        In cases where a channels-last-incompatible layer sits between convolution
        layers, we need to permute the input tensor back to contiguous format
        for that layer. The input tensor will go through the remaining layers in
        contiguous format and be permuted to channels last when it encounters
        another convolution layer. There's no point in propagating that
        permutation to an earlier layer, as most layers are quite agnostic to
        ``memory_format``.

        This might change once PyTorch supports fusion of permutations, as
        there may be a better spot to fuse the permutation than immediately
        before a convolution.

    Args:
        module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
                            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last_3d`` or ``torch.contiguous_format``

    Returns:
        The original module with updated ``nn.Conv3d`` weights

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
        >>> model = nn.Sequential(
        >>>     nn.Conv3d(8, 4, 3)).cuda().half()
        >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
        >>> out = model(input)
    """
    # TODO: expand this to `_ConvNd` when channels_last support is extended
    # beyond only 4d and 5d tensors.
    if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
        # Clone the weight in the requested memory format, then resize the
        # parameter's data in place so it adopts the new strides.
        weight_data = (
            module.weight.detach().clone().contiguous(memory_format=memory_format)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=memory_format
        )
    for child in module.children():
        convert_conv3d_weight_memory_format(child, memory_format)
    return module
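

# Minimal end-to-end sketch, assuming a CUDA device with cuDNN is available.
# It is guarded by ``__main__`` so importing this module has no side effects;
# everything below is illustrative and not part of the torch.nn.utils API.
if __name__ == "__main__":
    import torch.nn as nn

    if torch.cuda.is_available():
        model = (
            nn.Sequential(nn.Conv2d(8, 4, 3), nn.ReLU(), nn.Conv2d(4, 4, 3))
            .cuda()
            .half()
        )
        model = convert_conv2d_weight_memory_format(model, torch.channels_last)
        # Only the convolution weights were converted to channels_last.
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                assert m.weight.is_contiguous(memory_format=torch.channels_last)
        x = torch.randn(2, 8, 16, 16, dtype=torch.float16, device="cuda")
        out = model(x)
        # With channels_last weights, cuDNN typically produces the output in
        # channels_last as well, so downstream convolutions avoid re-permuting.
        print(out.is_contiguous(memory_format=torch.channels_last))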