xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/trainer/ddp_models.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from torch.nn.parallel import DistributedDataParallel as DDP
2
3
def basic_ddp_model(self, rank, model, process_group, hook_state, hook):
    r"""
    Wrap ``model`` in DistributedDataParallel and attach a communication hook.

    The DDP wrapper is pinned to the single device given by ``rank`` and
    bound to ``process_group``. A fresh hook-state object is instantiated
    from the ``hook_state`` class and registered together with ``hook`` on
    the wrapped model.
    Args:
        rank (int): worker rank
        model (nn.Module): neural network model
        process_group (ProcessGroup): distributed process group
        hook_state (class): class that will be used to keep track of state
            during training.
        hook (function): ddp communication hook
    """
    wrapped_model = DDP(model, device_ids=[rank], process_group=process_group)
    # Instantiate the state object the hook will carry between invocations.
    state = hook_state(self, process_group)
    wrapped_model.register_comm_hook(state, hook)
    return wrapped_model, state
22