# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, List, no_type_check

import torch
import torch.distributed as dist
from torch.autograd import Variable


__all__: List[str] = []

_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"


class _OptimizerHookState:
    """
    Holds state for running the optimizer in-line after a DDP communication hook.

    Currently holds only the functional optimizer, which must implement a
    `step_param` method.
    """

    __slots__ = ["functional_optimizer", "params_to_optimize"]

    def __init__(self, functional_optim, params=None):
        self.functional_optimizer = functional_optim
        self._check_valid_functional_optim()
        self._set_params_to_optimize(params)

    def _set_params_to_optimize(self, params):
        if params is not None:
            self.params_to_optimize = set(params)

    def _check_valid_functional_optim(self):
        if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
            raise ValueError(
                f"Class {type(self.functional_optimizer)} must implement method "
                f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}."
            )


@dataclass
class _OptimInBackwardHookState:
    optim_stream: torch.cuda.Stream
    wait_for_optim_stream_enqueued: bool


@no_type_check
def _apply_optim_in_backward_hook(
    gradient_is_bucket_view: bool,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Create a DDP communication hook that applies the optimizer in backward.

    If torch.distributed.optim._apply_optimizer_in_backward is used to overlap
    the optimizer with the backward pass, DDP runs the hook below to perform the
    optimizer step for parameters once their gradient communication has taken
    place.
    """
    optim_in_bwd_state = _OptimInBackwardHookState(
        optim_stream=torch.cuda.Stream(),
        wait_for_optim_stream_enqueued=False,
    )

    def apply_optim_in_backward_hook(
        hook_state: Any,
        bucket: dist.GradBucket,
        optim_stream_state,
    ) -> torch.futures.Future[torch.Tensor]:
        # Run original hook
        ddp_weakref = hook_state
        ddp_inst = ddp_weakref()
        reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
        fut = reducer._run_allreduce_hook(bucket)
        optimizer_stream = optim_stream_state.optim_stream
        with torch.cuda.stream(optimizer_stream):
            fut.wait()
            # Apply gradient division since the C++ side only allreduces and
            # does not average. TODO (rohan-varma): the div factor may be
            # different when running with the join hook.
            bucket.buffer().div_(process_group.size())
            model_params = bucket.parameters()
            grads = bucket.gradients()
            # TODO (rohan-varma): upcast as needed for DDP mixed precision,
            # once optimizer in backward + DDP mixed precision is supported.
            for p, g in zip(model_params, grads):
                if hasattr(p, "_in_backward_optimizers"):
                    # Note: need to set grad to the bucket's grad, because
                    # running allreduce results in the bucket's grad being
                    # reduced, but not the grad field.
                    if not gradient_is_bucket_view:
                        p.grad = g
                    for optim in p._in_backward_optimizers:
                        optim.step()

        # Need to return a Future[Tensor] to obey the comm hook API contract.
        ret_fut = torch.futures.Future()
        ret_fut.set_result(bucket.buffer())

        # Enqueue a callback to wait for this optimizer stream at the end of
        # backward and set all DDP-managed grads to None.
        def wait_for_optim_stream_callback():
            torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
            # Set DDP managed grads to None
            for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
                if hasattr(param, "_in_backward_optimizers"):
                    param.grad = None

            # reset for the next backwards pass
            optim_stream_state.wait_for_optim_stream_enqueued = False

        if not optim_stream_state.wait_for_optim_stream_enqueued:
            Variable._execution_engine.queue_callback(wait_for_optim_stream_callback)
            # mark that the callback is enqueued
            optim_stream_state.wait_for_optim_stream_enqueued = True

        return ret_fut

    comm_hook = partial(
        apply_optim_in_backward_hook, optim_stream_state=optim_in_bwd_state
    )
    # These are needed for DDP's logging of comm hooks
    comm_hook.__name__ = apply_optim_in_backward_hook.__name__
    comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__

    return comm_hook


def _hook_then_optimizer(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
    optimizer_state: _OptimizerHookState,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""Run optimizer in a functional fashion after DDP communication hook."""
    has_set_params = (
        hasattr(optimizer_state, "params_to_optimize")
        and optimizer_state.params_to_optimize is not None
    )

    def hook_then_optimizer_wrapper(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Run original hook
        fut = hook(hook_state, bucket)

        def optimizer_step(fut):
            gradient_tensors = bucket.gradients()
            model_params = bucket.parameters()
            for grad_tensor, model_param in zip(gradient_tensors, model_params):
                if (
                    not has_set_params
                    or model_param in optimizer_state.params_to_optimize
                ):
                    optimizer_state.functional_optimizer.step_param(
                        model_param,
                        grad_tensor,
                    )
            return bucket.buffer()

        return fut.then(optimizer_step)

    return hook_then_optimizer_wrapper
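

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows one
# possible way to wire ``_hook_then_optimizer`` into DDP via the public
# ``register_comm_hook`` API. The ``ddp_model``/``process_group`` arguments and
# the helper below are assumptions for illustration only; ``allreduce_hook``
# and ``_FunctionalSGD`` are existing PyTorch utilities. The in-backward path
# (``_apply_optim_in_backward_hook``) is, per its docstring above, installed by
# DDP itself when torch.distributed.optim._apply_optimizer_in_backward has been
# applied to the parameters before wrapping the model, so it is not shown here.
def _example_hook_then_optimizer_usage(ddp_model, process_group):
    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
        allreduce_hook,
    )
    from torch.distributed.optim import _FunctionalSGD

    # Functional optimizer whose ``step_param`` is invoked per parameter once
    # the bucket holding its gradient has been allreduced.
    functional_optim = _FunctionalSGD(list(ddp_model.parameters()), lr=1e-2)
    optim_state = _OptimizerHookState(
        functional_optim, params=ddp_model.parameters()
    )
    # ``allreduce_hook`` expects the process group as its hook state; the
    # wrapper returned by ``_hook_then_optimizer`` forwards it unchanged.
    # With this hook installed, a separate optimizer.step() for these
    # parameters would be redundant.
    ddp_model.register_comm_hook(
        process_group, _hook_then_optimizer(allreduce_hook, optim_state)
    )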