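"""Replication of a module across devices.

``replicate`` builds one copy of a module per target device, broadcasting its
parameters and buffers and mirroring its submodule structure; it is used by
``torch.nn.DataParallel``.
"""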
from collections import OrderedDict
from typing import (
    cast,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    TYPE_CHECKING,
    TypeVar,
    Union,
)

import torch
from torch._utils import _get_device_index
from torch.nn.modules import Module
from torch.nn.parallel import comm


if TYPE_CHECKING:
    from torch.jit import ScriptModule
    from torch.jit._state import EnabledProxy


__all__ = ["replicate"]


def _is_script_module(module: Module) -> bool:
    import torch.jit

    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module: Module) -> bool:
    import torch.jit

    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module() -> "ScriptModule":
    import torch.jit

    return torch.jit.ScriptModule()


def _is_jit_enabled() -> "EnabledProxy":
    import torch.jit._state

    return torch.jit._state._enabled

# Check whether we can safely replicate the module.
# There are two types of module:
#   1. Python modules
#   2. ScriptModule
#
# Currently a module cannot be replicated properly if the descendants of
# any ScriptModule contain a Python module (type 1 above).
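#
# A minimal illustration (a sketch, assuming stock torch.nn and torch.jit):
#
#   import torch.nn as nn
#   net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#   _replicatable_module(net)                    # True: only plain Python modules
#   _replicatable_module(torch.jit.script(net))  # True: every descendant of the
#                                                # scripted module is a ScriptModule
#
# A ScriptModule that has a plain Python module among its descendants would
# make the check return False.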
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
    # module.modules() contains module itself as the first element
    def descendant_modules(module: Module) -> Iterator[Module]:
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(
            _is_script_module(descendant) for descendant in descendant_modules(module)
        )

    for child in module.children():
        # since any unreplicatable module will cause the check to return
        # False early, visited modules here can be safely ignored.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True

def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    from torch.nn.parallel._functions import Broadcast

    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the Broadcast autograd function when not detaching, so gradients
        # flow from the copies back to the original tensors.
        if len(tensors) > 0:
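            # Layout note (an illustrative sketch, based on how Broadcast and
            # comm.broadcast_coalesced order their outputs): Broadcast.apply
            # returns a flat tuple of len(devices) * len(tensors) copies in
            # device-major order, so slicing it into chunks of len(tensors)
            # below yields one list of copies per device, e.g. with two
            # devices and tensors (t0, t1, t2):
            #   [[t0@dev0, t1@dev0, t2@dev0], [t0@dev1, t1@dev1, t2@dev1]]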
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [
                tensor_copies[i : i + len(tensors)]
                for i in range(0, len(tensor_copies), len(tensors))
            ]
        else:
            return []


T = TypeVar("T", bound=Module)

def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
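    """Replicate ``network`` onto each device in ``devices``.

    Parameters and buffers are broadcast to every device (through autograd
    unless ``detach`` is set, so gradients can flow from the copies back to
    the originals), and each submodule is shallow-copied with
    ``_replicate_for_data_parallel`` and rewired so that every replica
    mirrors the original module hierarchy.

    Args:
        network: the module to replicate; its tensors are expected to live on
            a CUDA device (typically ``devices[0]``).
        devices: sequence of device indices or ``torch.device`` objects.
        detach: if ``True``, the broadcast copies are detached from the
            autograd graph.

    Returns:
        A list with one replica of ``network`` per entry in ``devices``.

    Example (a minimal sketch, assuming at least two CUDA devices)::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.Linear(4, 4).to("cuda:0")
        >>> replicas = replicate(net, [0, 1])
        >>> replicas[1].weight.device
        device(type='cuda', index=1)
    """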
    if not _replicatable_module(network):
        raise RuntimeError(
            "Cannot replicate network where Python modules are "
            "children of a ScriptModule"
        )

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    # Split buffers by whether they should go through autograd: buffers that
    # require grad are broadcast like parameters (unless ``detach`` is set),
    # everything else is always broadcast detached.
    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(
        buffers_not_rg, devices, detach=True
    )

    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]
    module_indices: Dict[Module, int] = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `model.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param_copy = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param_copy)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param_copy
        for key, buf in module._buffers.items():  # type: ignore[assignment]
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [cast(T, module_copies[j][0]) for j in range(num_replicas)]