# mypy: allow-untyped-defs
# Copyright 2022 Cruise LLC
import logging
import warnings
from collections import OrderedDict
from typing import Dict, Iterable, Union

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.utils as utils


logger = logging.getLogger(__name__)


class HierarchicalModelAverager(averagers.ModelAverager):
    r"""
    Runs hierarchical model averaging (`hierarchical SGD <https://arxiv.org/pdf/2010.12998.pdf>`_).

    Process groups of different sizes are organized in a hierarchy, and they average parameters
    by using different periods concurrently after the warm-up stage.
    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
    that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
    Similarly, the process groups within this class do not have such an intra-machine process
    subgroup, which should be embedded by the post-local SGD communication hook instead.

    Args:
        period_group_size_dict: An ordered dict mapping each model averaging period to a
            process group size, used for initializing process groups of
            different sizes in a hierarchy to average parameters concurrently.
            In particular, at each iteration, there will be at most a single
            process group that runs averaging -- the period of such a group is
            the largest period that divides the current step.
            For example, if the dict has three keys: 2, 4, and 8,
            then three process groups in total will be created to
            average parameters every 2, 4, and 8 iterations, respectively.
            At the 4th iteration, only the second process group will run
            averaging, because the first process group should be a
            subset of the second process group, so there is no need to run
            the first process group redundantly.
            On the other hand, the third process group can only be triggered
            every 8 iterations, so it will not be triggered at the 4th iteration.
        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
        process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
            If ``None``, the default process group, which is created
            by :func:`torch.distributed.init_process_group`, will be used.
            (default: ``None``)

    Example::
        >>> # xdoctest: +SKIP('undefined rank')
        >>> from collections import OrderedDict
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).to(rank)
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> # Assume that each machine has 4 GPUs, so each intra-machine subgroup has a size of 4.
        >>> subgroup, _ = dist.new_subgroups()
        >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
        >>> # the 16 processes every 16 iterations.
        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging at two levels.
        >>> for step in range(0, 200):
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     # Average parameters after ``optimizer.step()``.
        >>>     # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
        >>>     averager.average_parameters(model.parameters())

    .. warning ::
        The last group size in the dict must be the size of the provided ``process_group``,
        which indicates model averaging at the highest level of the hierarchy.
        If ``process_group`` is not provided, then the last group size should be equal to the world size.

    .. warning ::
        `HierarchicalModelAverager` is experimental and subject to change.
    """

    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
        super().__init__(process_group)
        if not period_group_size_dict:
            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
        self._periods = list(period_group_size_dict.keys())
        if self._periods[0] <= 0:
            raise ValueError(
                "The minimum period in arg ``period_group_size_dict`` must be a positive value."
            )
        elif self._periods[-1] == 1:
            warnings.warn(
                "When the maximum period in arg ``period_group_size_dict`` is 1, "
                "there is no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        overall_group_size = dist.get_world_size(group=self.process_group)
        if list(period_group_size_dict.values())[-1] != overall_group_size:
            raise ValueError(
                f"The last value in arg ``period_group_size_dict`` {list(period_group_size_dict.values())[-1]} "
                f"must be equal to the size of arg ``process_group`` {overall_group_size}."
            )

        self.period_process_group_dict = OrderedDict()
        logger.info("Model averaging hierarchy:")
        for period, group_size in period_group_size_dict.items():
            logger.info(
                "\tEach group that has %s processes averages parameters every %s iterations, "
                "if there is no higher-level averaging.",
                group_size,
                period,
            )
            if group_size != overall_group_size:
                self.period_process_group_dict[period], _ = dist.new_subgroups(
                    group_size=group_size, group=self.process_group
                )
            else:
                self.period_process_group_dict[period] = self.process_group

        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps

    def _find_process_group(self):
        """
        Return a process group as the value of a ``period_process_group_dict`` entry.

        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
        then the returned process group is the one corresponding to the largest period,
        since this process group will be used for averaging parameters at this ``step``.
        Returns ``None`` if not found.
        """
        for period in reversed(self._periods):
            if self.step % period == 0:
                return self.period_process_group_dict[period]
        return None

    def average_parameters(
        self,
        params: Union[
            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
        ],
    ):
        """
        Averages parameters or parameter groups of an optimizer.

        Averaging only occurs if ``step`` is no less than ``warmup_steps``
        and it can be divided by a period in the keys of ``period_process_group_dict``,
        where ``step`` is increased by 1 at each iteration in the training loop.
        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
        only the largest period is used, and the corresponding process group is used for averaging parameters.

        Args:
            params: The parameters of a model or parameter groups of an optimizer.
        """
        if self.step >= self.warmup_steps:
            group = self._find_process_group()
            if group is not None:
                utils.average_parameters_or_parameter_groups(params, group)
        self.step += 1
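

# The sketch below is illustrative only and is not used anywhere in this module:
# it reproduces the period-selection rule of
# ``HierarchicalModelAverager._find_process_group`` in plain Python, so the
# mapping from training steps to averaging levels can be checked without creating
# any process groups. The helper name and default periods are hypothetical.
def _example_largest_dividing_period(step, periods=(2, 4, 8)):
    """Return the largest period in ``periods`` that divides ``step``, or ``None``.

    >>> [_example_largest_dividing_period(s) for s in range(1, 9)]
    [None, 2, None, 4, None, 2, None, 8]
    """
    # Check periods from largest to smallest so that, when several periods divide
    # the step, only the highest level of the hierarchy is selected.
    for period in sorted(periods, reverse=True):
        if step % period == 0:
            return period
    return None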