# mypy: allow-untyped-defs
import torch


def convert_conv2d_weight_memory_format(module, memory_format):
    r"""Convert the ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``.

    The conversion recursively applies to nested ``nn.Module``, including ``module``.
    Note that it only changes the memory_format, not the semantics of each dimension.
    This function is used to facilitate the computation to adopt NHWC kernels, which
    provide considerable speed-up for fp16 data on CUDA devices with compute
    capability >= 7.0.

    .. note::
        Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive
        than the utility function ``convert_conv2d_weight_memory_format``. Any
        layer with a 4d weight will be affected by ``model.to``, including layers
        that do not necessarily benefit from conversion to the specified
        ``memory_format``. One case we are confident about is the NHWC
        (channels_last) conversion for convolution in cuDNN, as it is beneficial
        to run convolution in NHWC, even in cases where we have to apply a
        permutation to the input tensors.

        Hence our strategy here is to convert only the weight of convolution to
        channels_last. This ensures that:
        1. Fast convolution kernels will be used, the benefit of which could
        outweigh the overhead of permutation (if the input is not in the same format).
        2. No unnecessary permutations are applied to layers that do not benefit
        from the memory_format conversion.

        The optimal case is when the layers between convolution layers are
        channels-last compatible. The input tensor is permuted to channels last
        when it reaches the first convolution layer and stays in that memory
        format, so the following convolutions do not need to permute their
        input tensors.

        When a channels-last-incompatible layer sits between convolution layers,
        we need to permute the input tensor back to contiguous format for that
        layer. The input tensor then goes through the remaining layers in
        contiguous format and is permuted to channels last when it reaches
        another convolution layer. There is no point in propagating that
        permutation to an earlier layer, as most layers are quite agnostic to
        ``memory_format``.

        This claim might change when PyTorch supports fusion of permutations, as
        there might be a better spot to fuse the permutation than immediately
        before a convolution.

    Args:
        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
                            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last`` or ``torch.contiguous_format``

    Returns:
        The original module with updated ``nn.Conv2d``

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
        >>> model = nn.Sequential(
        >>>     nn.Conv2d(8, 4, 3)).cuda().half()
        >>> # This is identical to:
        >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> out = model(input)
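        >>> # Illustrative sanity check (continuing the example above): the
        >>> # converted weight should now be channels_last-contiguous.
        >>> assert model[0].weight.is_contiguous(memory_format=torch.channels_last)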
65    """
66    # TODO: expand this to `_ConvNd` when channels_last support is extended
67    # beyond only 4d tensors.
68    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
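        # Clone the weight in the requested memory_format and swap the re-strided
        # tensor in through ``.data`` so the Parameter object itself is preserved.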
        weight_data = (
            module.weight.detach().clone().contiguous(memory_format=memory_format)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=memory_format
        )
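    # Recurse into children so convolutions nested in containers are converted too.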
    for child in module.children():
        convert_conv2d_weight_memory_format(child, memory_format)
    return module


def convert_conv3d_weight_memory_format(module, memory_format):
81    r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``
82    The conversion recursively applies to nested ``nn.Module``, including ``module``.
83    Note that it only changes the memory_format, but not the semantics of each dimensions.
84    This function is used to facilitate the computation to adopt NHWC kernels, which
85    provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
86
87    .. note::
88        Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
89        than the utility function ``convert_conv3d_weight_memory_format``. Any
90        layer with 4d weight will be affected by ``model.to``, which does not
91        necessarily benefit from conversion to specified ``memory_format``.
92        One place we are confident in is that NDHWC(channels_last_3d) conversion for
93        convolution in cuDNN, as it is beneficial to run convolution in NDHWC,
94        even in cases where we have to apply permutation to input tensors.
95
96        Hence our strategy here is to convert only the weight of convolution to
97        channels_last_3d. This ensures that;
98        1. Fast convolution kernels will be used, the benefit of which could
99        outweigh overhead of permutation (if input is not in the same format).
100        2. No unnecessary permutations are applied on layers that do not benefit
101        from memory_format conversion.
102
103        The optimal case is that, layers between convolution layers are channels
104        last compatible. Input tensor would be permuted to channels last when it
105        encounters the first convolution layer and stay in that memory format.
106        Hence following convolutions will not need to permute its input tensor.
107
108        In case where a channels last incompatible layer is between convolution
109        layers, we need to permute the input tensor back to contiguous format
110        for that layer. The input tensor will go through the remaining layers in
111        contiguous format and be permuted to channels last when it encounters
112        another convolution layer. There's no point in propagating that
113        permutation to an earlier layer, as most layers are quite agnostic to
114        ``memory_format``.
115
116        This claim might change when PyTorch supports fusion of permutation, as
117        there might have been a better spot to fuse the permutation other than
118        immediately before a convolution.
119
120    Args:
121        module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
122                            ``nn.Module``
123        memory_format: user specified ``memory_format``,
124            e.g. ``torch.channels_last`` or ``torch.contiguous_format``
125
126    Returns:
127        The original module with updated ``nn.Conv3d``
128
129    Example:
130        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
131        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
132        >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
133        >>> model = nn.Sequential(
134        >>>     nn.Conv3d(8, 4, 3)).cuda().half()
135        >>> # This is identical to:
136        >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
137        >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
138        >>> out = model(input)
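        >>> # Illustrative sanity check (continuing the example above): the
        >>> # converted weight should now be channels_last_3d-contiguous.
        >>> assert model[0].weight.is_contiguous(memory_format=torch.channels_last_3d)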
139    """
140
141    # TODO: expand this to `_ConvNd` when channels_last support is extended
142    # beyond only 4d tensors.
143    if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
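        # Same in-place swap as the 2d variant: clone the weight in the requested
        # memory_format and re-assign it through ``.data`` to keep the Parameter object.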
        weight_data = (
            module.weight.detach().clone().contiguous(memory_format=memory_format)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=memory_format
        )
    for child in module.children():
        convert_conv3d_weight_memory_format(child, memory_format)
    return module