from torch.nn.parallel import DistributedDataParallel as DDP


def basic_ddp_model(self, rank, model, process_group, hook_state, hook):
    """Wrap ``model`` in DDP and register a communication hook on it.

    The model is wrapped in ``DistributedDataParallel`` pinned to a single
    device id (``rank``) and the given process group. ``hook_state`` is
    instantiated with ``(self, process_group)`` and passed as the state
    object when registering ``hook``.

    Args:
        rank (int): worker rank; also used as the single device id.
        model (nn.Module): neural network model to wrap.
        process_group (ProcessGroup): distributed process group.
        hook_state (class): class used to keep track of state during
            training; instantiated here.
        hook (function): DDP communication hook.

    Returns:
        tuple: the DDP-wrapped model and the ``hook_state`` instance.
    """
    wrapped_model = DDP(model, device_ids=[rank], process_group=process_group)
    state = hook_state(self, process_group)
    wrapped_model.register_comm_hook(state, hook)
    return wrapped_model, state