from collections import OrderedDict
from typing import (
    cast,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    TYPE_CHECKING,
    TypeVar,
    Union,
)

import torch
from torch._utils import _get_device_index
from torch.nn.modules import Module
from torch.nn.parallel import comm


if TYPE_CHECKING:
    from torch.jit import ScriptModule
    from torch.jit._state import EnabledProxy


__all__ = ["replicate"]


def _is_script_module(module: Module) -> bool:
    import torch.jit

    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module: Module) -> bool:
    import torch.jit

    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module() -> "ScriptModule":
    import torch.jit

    return torch.jit.ScriptModule()


def _is_jit_enabled() -> "EnabledProxy":
    import torch.jit._state

    return torch.jit._state._enabled


# Check if we can safely replicate the module.
# There are two types of modules:
# 1. Python modules
# 2. ScriptModules
#
# Currently a module cannot be replicated properly if the descendants of
# any ScriptModule contain a Python module (type 1 above).
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
    # module.modules() contains module itself as the first element
    def descendant_modules(module: Module) -> Iterator[Module]:
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(
            _is_script_module(descendant) for descendant in descendant_modules(module)
        )

    for child in module.children():
        # since any unreplicatable module will cause the check to return
        # False early, visited modules here can be safely ignored.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True
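

# Illustrative sketch of the check above (kept as a comment, not executed on
# import): a tree of plain Python modules is replicatable, and so is a fully
# scripted tree. The example names below are assumptions made up for
# illustration, not part of this module's API.
#
#     import torch.nn as nn
#
#     plain = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#     assert _replicatable_module(plain)       # pure Python tree -> True
#
#     scripted = torch.jit.script(plain)
#     assert _replicatable_module(scripted)    # all-ScriptModule tree -> True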


def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    from torch.nn.parallel._functions import Broadcast

    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the autograd function to broadcast if not detach
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [
                tensor_copies[i : i + len(tensors)]
                for i in range(0, len(tensor_copies), len(tensors))
            ]
        else:
            return []
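

# Illustrative sketch of the contract above (a hand-written comment example,
# assuming at least two CUDA devices are visible): the result is indexed first
# by device and then by the position of the tensor in the input sequence, so
# copies[d][k] is the copy of tensors[k] on devices[d].
#
#     if torch.cuda.device_count() >= 2:
#         a = torch.randn(3, device="cuda:0")
#         b = torch.randn(5, device="cuda:0")
#         copies = _broadcast_coalesced_reshape([a, b], [0, 1], detach=True)
#         assert len(copies) == 2                              # one entry per device
#         assert copies[1][0].device == torch.device("cuda:1")
#         assert copies[1][1].shape == b.shape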


T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
    if not _replicatable_module(network):
        raise RuntimeError(
            "Cannot replicate network where python modules are "
            "children of ScriptModule"
        )

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(
        buffers_not_rg, devices, detach=True
    )

    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]
    module_indices: Dict[Module, int] = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param_copy = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param_copy)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param_copy
        for key, buf in module._buffers.items():  # type: ignore[assignment]
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [cast(T, module_copies[j][0]) for j in range(num_replicas)]
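

# Example usage (illustrative sketch, kept as a comment; assumes at least two
# CUDA devices). This mirrors how torch.nn.DataParallel drives replicate():
# the module is cloned onto each device and every replica processes its own
# shard of the batch. With detach=False the replicated parameters remain
# connected to the originals through autograd, so gradients reach the source
# module. The model and tensor shapes below are made up for illustration.
#
#     if torch.cuda.device_count() >= 2:
#         net = torch.nn.Linear(4, 2).cuda(0)
#         replicas = replicate(net, devices=[0, 1])
#         out0 = replicas[0](torch.randn(8, 4, device="cuda:0"))
#         out1 = replicas[1](torch.randn(8, 4, device="cuda:1"))
#         (out0.sum() + out1.sum().to("cuda:0")).backward()
#         assert net.weight.grad is not None   # gradients reach the original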