# xref: /aosp_15_r20/external/pytorch/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, List, no_type_check

import torch
import torch.distributed as dist
from torch.autograd import Variable


__all__: List[str] = []

_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"


class _OptimizerHookState:
    """
    Holds state for running the optimizer in-line after a DDP communication hook.

    Currently contains the functional optimizer, which must implement a
    ``step_param`` method, and optionally the set of parameters it should
    optimize. See the illustrative sketch following this class.
    """

    __slots__ = ["functional_optimizer", "params_to_optimize"]

    def __init__(self, functional_optim, params=None):
        self.functional_optimizer = functional_optim
        self._check_valid_functional_optim()
        self._set_params_to_optimize(params)

    def _set_params_to_optimize(self, params):
        if params is not None:
            self.params_to_optimize = set(params)

    def _check_valid_functional_optim(self):
        if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
            raise ValueError(
                f"Class {type(self.functional_optimizer)} must implement method "
                f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}."
            )
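
# Illustrative sketch (not part of the original module): how a functional optimizer
# exposing ``step_param`` might be wrapped in ``_OptimizerHookState``. The concrete
# optimizer class and constructor arguments below are assumptions, and ``model`` is a
# placeholder module; any object with a ``step_param(param, grad)`` method would
# satisfy the check above.
#
#     from torch.distributed.optim.functional_sgd import _FunctionalSGD
#
#     functional_optim = _FunctionalSGD([], lr=1e-2, _allow_empty_param_list=True)
#     hook_state = _OptimizerHookState(
#         functional_optim,
#         params=model.parameters(),  # restrict the optimizer step to these parameters
#     )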


@dataclass
class _OptimInBackwardHookState:
    optim_stream: torch.cuda.Stream
    wait_for_optim_stream_enqueued: bool


@no_type_check
def _apply_optim_in_backward_hook(
    gradient_is_bucket_view: bool,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Build a DDP communication hook that applies the optimizer in backward.

    If ``torch.distributed.optim._apply_optimizer_in_backward`` is used to overlap
    the optimizer with the backward pass, DDP runs the hook below to step the
    optimizer for a bucket's parameters once gradient communication for that
    bucket has taken place. See the illustrative sketch following this function.
    """
    optim_in_bwd_state = _OptimInBackwardHookState(
        optim_stream=torch.cuda.Stream(),
        wait_for_optim_stream_enqueued=False,
    )

    def apply_optim_in_backward_hook(
        hook_state: Any,
        bucket: dist.GradBucket,
        optim_stream_state,
    ) -> torch.futures.Future[torch.Tensor]:
        # Run original hook
        ddp_weakref = hook_state
        ddp_inst = ddp_weakref()
        reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
        fut = reducer._run_allreduce_hook(bucket)
        optimizer_stream = optim_stream_state.optim_stream
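        # Running the per-bucket optimizer step on a dedicated CUDA stream lets it
        # overlap with the rest of the backward pass (and with communication for
        # later buckets); the end-of-backward callback below synchronizes with this
        # stream before gradients are cleared.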
        with torch.cuda.stream(optimizer_stream):
            fut.wait()
            # Apply gradient division since the C++ side only allreduces and does
            # not average. TODO: (rohan-varma) the div factor may be different
            # when running with join hook
            bucket.buffer().div_(process_group.size())
            model_params = bucket.parameters()
            grads = bucket.gradients()
            # TODO (rohan-varma): upcast as needed for DDP mixed precision,
            # once optimizer in backward + DDP mixed precision is supported.
            for p, g in zip(model_params, grads):
                if hasattr(p, "_in_backward_optimizers"):
                    # Note: point p.grad at the bucket's gradient, because allreduce
                    # reduces the bucket's buffer but not the parameter's .grad field.
                    if not gradient_is_bucket_view:
                        p.grad = g
                    for optim in p._in_backward_optimizers:
                        optim.step()

        # Need to return a Future[Tensor] to obey the comm hook API contract.
        ret_fut = torch.futures.Future()
        ret_fut.set_result(bucket.buffer())

        # Enqueue a callback to wait on this optimizer stream at the end of
        # backward and set all DDP-managed grads to None.
        def wait_for_optim_stream_callback():
            torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
            # Set DDP-managed grads to None.
            for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
                if hasattr(param, "_in_backward_optimizers"):
                    param.grad = None

            # Reset for the next backward pass.
            optim_stream_state.wait_for_optim_stream_enqueued = False

        if not optim_stream_state.wait_for_optim_stream_enqueued:
            Variable._execution_engine.queue_callback(wait_for_optim_stream_callback)
            # Mark that the callback is enqueued.
            optim_stream_state.wait_for_optim_stream_enqueued = True

        return ret_fut

    comm_hook = partial(
        apply_optim_in_backward_hook, optim_stream_state=optim_in_bwd_state
    )
    # These are needed for DDP's logging of comm hooks.
    comm_hook.__name__ = apply_optim_in_backward_hook.__name__
    comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__

    return comm_hook
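
# Illustrative sketch (assumption, not from the original file): in practice DDP wires
# this hook up itself when parameters carry in-backward optimizers; the explicit
# ``register_comm_hook`` call below only mirrors that internal wiring, and
# ``model``/``device_id`` are placeholders.
#
#     import weakref
#
#     from torch.distributed.optim import _apply_optimizer_in_backward
#     from torch.nn.parallel import DistributedDataParallel as DDP
#
#     _apply_optimizer_in_backward(
#         torch.optim.SGD, model.parameters(), optimizer_kwargs={"lr": 1e-2}
#     )
#     ddp_model = DDP(model, device_ids=[device_id])
#     ddp_model.register_comm_hook(
#         weakref.ref(ddp_model),
#         _apply_optim_in_backward_hook(gradient_is_bucket_view=False),
#     )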


def _hook_then_optimizer(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
    optimizer_state: _OptimizerHookState,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Run the functional optimizer on each bucket's parameters after ``hook`` completes.

    See the illustrative sketch following this function.
    """
    has_set_params = (
        hasattr(optimizer_state, "params_to_optimize")
        and optimizer_state.params_to_optimize is not None
    )

    def hook_then_optimizer_wrapper(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Run original hook
        fut = hook(hook_state, bucket)

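        # Once the communication future completes, step the functional optimizer
        # per parameter via ``step_param`` using the bucket's reduced gradients,
        # restricted to ``params_to_optimize`` when that set was provided.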
        def optimizer_step(fut):
            gradient_tensors = bucket.gradients()
            model_params = bucket.parameters()
            for grad_tensor, model_param in zip(gradient_tensors, model_params):
                if (
                    not has_set_params
                    or model_param in optimizer_state.params_to_optimize
                ):
                    optimizer_state.functional_optimizer.step_param(
                        model_param,
                        grad_tensor,
                    )
            return bucket.buffer()

        return fut.then(optimizer_step)

    return hook_then_optimizer_wrapper
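
# Illustrative sketch (assumption, not from the original file): composing the
# allreduce comm hook with a functional optimizer so each bucket's parameters are
# stepped as soon as that bucket has been reduced. ``ddp_model`` and
# ``process_group`` are placeholders; the functional optimizer only needs to
# expose ``step_param``.
#
#     from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
#         allreduce_hook,
#     )
#     from torch.distributed.optim.functional_sgd import _FunctionalSGD
#
#     opt_state = _OptimizerHookState(
#         _FunctionalSGD([], lr=1e-2, _allow_empty_param_list=True),
#         params=None,  # None means: step every parameter in every bucket
#     )
#     ddp_model.register_comm_hook(
#         process_group, _hook_then_optimizer(allreduce_hook, opt_state)
#     )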