# mypy: allow-untyped-defs
# Copyright 2022 Cruise LLC
import logging
import warnings
from collections import OrderedDict
from typing import Dict, Iterable, Union

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.utils as utils


logger = logging.getLogger(__name__)


class HierarchicalModelAverager(averagers.ModelAverager):
    r"""
    Runs hierarchical model averaging (`hierarchical SGD <https://arxiv.org/pdf/2010.12998.pdf>`_).

    Process groups of different sizes are organized in a hierarchy, and they average parameters
    by using different periods concurrently after the warm-up stage.
    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
    that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
    Similarly, the process groups created by this class do not include such an intra-machine
    subgroup, which should instead be embedded in the post-local SGD communication hook.

    Args:
        period_group_size_dict: An ordered dict mapping a model averaging period to a
                                process group size. It is used to initialize process groups of
                                different sizes in a hierarchy that average parameters concurrently.
                                At each iteration, at most one process group runs averaging:
                                the group whose period is the largest one that divides the
                                current step (see the sketch after this argument list).
                                For example, if the dict has three keys: 2, 4, and 8,
                                then three process groups will be created to average
                                parameters every 2, 4, and 8 iterations, respectively.
                                At the 4th iteration, only the second process group runs
                                averaging, because the first process group is a subset of
                                the second one, so there is no need to run the first group
                                redundantly.
                                On the other hand, the third process group can only be triggered
                                every 8 iterations, so it will not be triggered at the 4th iteration.
        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
        process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
                                                If ``None``, the default process group, which is created
                                                by :func:`torch.distributed.init_process_group`, will be used.
                                                (default: ``None``)
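
    As a plain-Python sketch of the selection rule above (purely illustrative, no process
    groups or distributed setup involved), the period that triggers averaging at a given
    ``step`` is the largest key of ``period_group_size_dict`` that divides ``step``::

        >>> periods = [2, 4, 8]
        >>> max(p for p in periods if 4 % p == 0)   # 4th iteration: the period-4 group averages
        4
        >>> max(p for p in periods if 8 % p == 0)   # 8th iteration: the period-8 group averages
        8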

    Example::
        >>> # xdoctest: +SKIP('undefined rank')
        >>> from collections import OrderedDict
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).to(rank)
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> # Assuming that each machine has 4 GPUs, each intra-machine subgroup has a size of 4.
        >>> subgroup, _ = dist.new_subgroups()
        >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
        >>> # the 16 processes every 16 iterations.
        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging at two levels.
        >>> for step in range(0, 200):
        >>>    optimizer.zero_grad()
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    optimizer.step()
        >>>    # Average parameters after ``optimizer.step()``.
        >>>    # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
        >>>    averager.average_parameters(model.parameters())

    .. warning ::
        The last group size in the dict must be the size of the provided ``process_group``,
        which indicates model averaging at the highest level of the hierarchy.
        If ``process_group`` is not provided, then the last group size should be equal to the world size.

    .. warning ::
        `HierarchicalModelAverager` is experimental and subject to change.
    """

    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
        super().__init__(process_group)
        if not period_group_size_dict:
            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
        self._periods = list(period_group_size_dict.keys())
        if self._periods[0] <= 0:
            raise ValueError(
                "The minimum period in arg ``period_group_size_dict`` must be a positive value."
            )
        elif self._periods[-1] == 1:
            warnings.warn(
                "When the maximum period in arg ``period_group_size_dict`` is 1, "
                "there is no need to use model averaging, because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used in this case."
            )
        overall_group_size = dist.get_world_size(group=self.process_group)
        if list(period_group_size_dict.values())[-1] != overall_group_size:
            raise ValueError(
                f"The last value in arg ``period_group_size_dict`` {list(period_group_size_dict.values())[-1]} "
                f"must be equal to the size of arg ``process_group`` {overall_group_size}."
            )

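        # Build one process group per period. A period whose group size equals the overall
        # group size reuses ``self.process_group`` directly; smaller sizes get fresh
        # subgroups created via ``dist.new_subgroups``.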
        self.period_process_group_dict = OrderedDict()
        logger.info("Model averaging hierarchy:")
        for period, group_size in period_group_size_dict.items():
            logger.info(
                "\tEach group that has %s processes averages parameters every %s iterations, "
                "if there is no higher-level averaging.",
                group_size,
                period,
            )
            if group_size != overall_group_size:
                self.period_process_group_dict[period], _ = dist.new_subgroups(
                    group_size=group_size, group=self.process_group
                )
            else:
                self.period_process_group_dict[period] = self.process_group

        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps

    def _find_process_group(self):
        """
        Return a process group from ``period_process_group_dict`` for the current ``step``.

        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
        then the returned process group is the one corresponding to the largest such period,
        since this process group will be used for averaging parameters at this ``step``.
        Returns ``None`` if no period divides ``step``.
        """
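        # ``self._periods`` follows the insertion order of ``period_group_size_dict``,
        # which is expected to be ascending, so iterating in reverse checks the largest
        # period first and returns its process group as soon as it divides ``step``.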
        for period in reversed(self._periods):
            if self.step % period == 0:
                return self.period_process_group_dict[period]
        return None

    def average_parameters(
        self,
        params: Union[
            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
        ],
    ):
        """
        Averages parameters or parameter groups of an optimizer.

        Averaging only occurs if ``step`` is no less than ``warmup_steps``
        and it can be divided by a period in the keys of ``period_process_group_dict``,
        where ``step`` is increased by 1 at each iteration in the training loop.
        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
        only the largest period is used, and the corresponding process group is used for averaging parameters.

        Args:
            params: The parameters of a model or parameter groups of an optimizer.
        """
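        # Skip averaging during warm-up. Afterwards, only average when the current step
        # matches some period, using the process group of the largest matching period.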
        if self.step >= self.warmup_steps:
            group = self._find_process_group()
            if group is not None:
                utils.average_parameters_or_parameter_groups(params, group)
        self.step += 1