# mypy: allow-untyped-defs
# flake8: noqa C101
import itertools
from typing import Dict, Iterable, Iterator, Union

import torch
import torch.distributed as dist

# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise an ImportError
# if we're trying to use them.
from torch.distributed import group, ProcessGroup


__all__ = [
    "average_parameters",
    "get_params_to_average",
    "average_parameters_or_parameter_groups",
]


def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
    """
    Averages all the given parameters.

    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
    Thus, it requires extra memory of the same size as the given parameters.
    """
    group_to_use = process_group if process_group is not None else group.WORLD
    # Do not update any parameter if this rank is not in the process group.
    if dist._rank_not_in_group(group_to_use):
        return

    params_it1, params_it2 = itertools.tee(params)
    # If the input parameters have different data types,
    # packing these parameters will trigger an implicit type up-casting.
    # The original parameter data types will be restored during the subsequent unpacking.
    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
    flat_params /= dist.get_world_size(group_to_use)
    # Make sure the allreduce will not conflict with any other ongoing process group.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.all_reduce(flat_params, group=group_to_use)

    offset = 0
    for p in params_it2:
        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
        offset += p.numel()


def get_params_to_average(
    params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
):
    """
    Return a list of parameters that need to be averaged.

    This filters out the parameters that do not have any gradients.

    Args:
        params: The parameters of a model or the parameter groups of an optimizer.
    """
    filtered_params = []
    for param in params:
        if isinstance(param, torch.nn.Parameter):
            # model.parameters() input
            param_data = param
            if param_data.grad is not None:
                filtered_params.append(param_data)
        elif isinstance(param, dict):
            # optimizer.param_groups input
            for param_data in param["params"]:
                if param_data.grad is not None:
                    filtered_params.append(param_data)
        else:
            raise NotImplementedError(
                f"Parameter input of type {type(param)} is not supported"
            )
    return filtered_params


def average_parameters_or_parameter_groups(
    params: Union[
        Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
    ],
    process_group: ProcessGroup,
):
    """Averages the parameters of a model or the parameter groups of an optimizer."""
    average_parameters(iter(get_params_to_average(params)), process_group)
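

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of this module's public API and
# intentionally left out of ``__all__``). This is a minimal sketch assuming
# the process is launched with ``torchrun`` (or an equivalent launcher) so
# that ``init_process_group("gloo")`` can read the rank and world size from
# the environment. The helper name ``_example_average_after_backward``, the
# toy ``nn.Linear`` model, and the "gloo" backend are made-up placeholders.
# ---------------------------------------------------------------------------
def _example_average_after_backward() -> None:
    """Sketch: average a toy model's parameters across all ranks."""
    dist.init_process_group("gloo")
    model = torch.nn.Linear(4, 2)
    # Run one backward pass so that every parameter has a gradient;
    # get_params_to_average() skips parameters whose .grad is None.
    model(torch.randn(8, 4)).sum().backward()
    # Average the model's parameters over the default (WORLD) process group.
    average_parameters_or_parameter_groups(model.parameters(), group.WORLD)
    # The same entry point also accepts optimizer parameter groups.
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    average_parameters_or_parameter_groups(optim.param_groups, group.WORLD)
    dist.destroy_process_group()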