# mypy: allow-untyped-defs
# flake8: noqa C101
import itertools
from typing import Dict, Iterable, Iterator, Union

import torch
import torch.distributed as dist

# The two imports below are not always available, depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise an ImportError
# if we're trying to use them.
from torch.distributed import group, ProcessGroup


__all__ = [
    "average_parameters",
    "get_params_to_average",
    "average_parameters_or_parameter_groups",
]


def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
    """
    Averages all the given parameters.

    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
    Thus, it requires extra memory of the same size as the given parameters.
    """
    group_to_use = process_group if process_group is not None else group.WORLD
    # Do not update any parameter if not in the process group.
    if dist._rank_not_in_group(group_to_use):
        return

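    # Duplicate the parameter iterator: the first copy is consumed when the
    # parameters are packed into the flat buffer, and the second copy is
    # consumed when the averaged values are unpacked back into the parameters.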
    params_it1, params_it2 = itertools.tee(params)
    # If the input parameters have different data types,
    # packing these parameters will trigger an implicit type up-casting.
    # The original parameter data types will be restored during the subsequent unpacking.
    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
    flat_params /= dist.get_world_size(group_to_use)
    # Make sure the allreduce will not conflict with any other ongoing process group.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.all_reduce(flat_params, group=group_to_use)

    offset = 0
    for p in params_it2:
        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
        offset += p.numel()

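# A minimal usage sketch (illustrative only, not part of the module API). It
# assumes ``dist.init_process_group`` has already been called, e.g. by
# launching the script with ``torchrun``:
#
#   model = torch.nn.Linear(10, 10)
#   average_parameters(model.parameters(), dist.group.WORLD)
#
# Afterwards, every rank in the group holds the element-wise average of the
# parameters across all ranks.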


def get_params_to_average(
    params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
):
    """
    Return a list of parameters that need to be averaged.

    This filters out the parameters that do not contain any gradients.

    Args:
        params: The parameters of a model or parameter groups of an optimizer.
    """
    filtered_params = []
    for param in params:
        if isinstance(param, torch.nn.Parameter):
            # model.parameters() input
            param_data = param
            if param_data.grad is not None:
                filtered_params.append(param_data)
        elif isinstance(param, dict):
            # optimizer.param_groups input
            for param_data in param["params"]:
                if param_data.grad is not None:
                    filtered_params.append(param_data)
        else:
            raise NotImplementedError(
                f"Parameter input of type {type(param)} is not supported"
            )
    return filtered_params

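# Both call forms are accepted (illustrative sketch; ``model`` and ``optim``
# are assumed to exist and a backward pass has already populated gradients):
#
#   get_params_to_average(model.parameters())   # iterable of Parameters
#   get_params_to_average(optim.param_groups)   # list of param-group dicts
#
# In both cases, parameters whose ``.grad`` is still ``None`` are skipped.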

def average_parameters_or_parameter_groups(
    params: Union[
        Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
    ],
    process_group: ProcessGroup,
):
    """Averages parameters of a model or parameter groups of an optimizer."""
    average_parameters(iter(get_params_to_average(params)), process_group)

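
if __name__ == "__main__":
    # Minimal smoke test, kept as an illustrative sketch rather than part of
    # the module. It assumes the script is launched with torchrun (e.g.
    # ``torchrun --nproc_per_node=2 utils.py``) so that RANK/WORLD_SIZE and
    # the rendezvous environment variables are set, and that the ``gloo``
    # backend is available.
    dist.init_process_group(backend="gloo")
    model = torch.nn.Linear(4, 4)
    # Run one forward/backward pass so every parameter has a gradient and is
    # therefore kept by ``get_params_to_average``.
    model(torch.randn(2, 4)).sum().backward()
    average_parameters_or_parameter_groups(model.parameters(), dist.group.WORLD)
    dist.destroy_process_group()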