# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import collections
import logging
import weakref
from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union

import torch
from torch.autograd.graph import GradientEdge, Node
from torch.nn import Parameter

from ._debug import map_debug_info


logger = logging.getLogger(__name__)


def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
    """
    Get the grad function or grad accumulator for a tensor.

    Accumulate grad nodes are lazily created, so we need to create a
    dummy view in order to trigger their creation.
    """
    if t.requires_grad and t.grad_fn is None:
        # if no grad function (leaf tensors) we use view
        viewed_t = t.view_as(t)
        grad_fn = viewed_t.grad_fn
        if grad_fn is not None:
            return grad_fn.next_functions[0][0]
        else:
            raise RuntimeError(
                "Attempted to get grad_fn, but got None. "
                "Is this being created in a no-grad context?"
            )
    else:
        return t.grad_fn
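
# A minimal usage sketch for _get_grad_fn_or_grad_acc (the tensors below are
# hypothetical, not part of this module). Leaf tensors expose their
# AccumulateGrad node only after the view trick; non-leaf tensors simply return
# their grad_fn:
#
#     leaf = torch.zeros(2, requires_grad=True)
#     leaf.grad_fn is None                       # True: leaves have no grad_fn
#     _get_grad_fn_or_grad_acc(leaf)             # <AccumulateGrad object ...>
#
#     out = leaf * 2
#     _get_grad_fn_or_grad_acc(out)              # <MulBackward0 object ...>, i.e. out.grad_fn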


def reverse_closure(
    roots: List[Node], target_nodes: Set[Node]
) -> Tuple[Set[Node], Set[Node]]:
    """
    Return the reverse closure of the given roots, i.e. the set of nodes that
    can be reached from the roots by following the reverse edges of the graph
    (as recorded in each node's "reverse_edges" metadata).

    Traversal stops at any node in target_nodes; such nodes are not added to
    the closure but are collected and returned separately as the second
    element of the returned tuple.
    """
    # Recurse until we reach a target node
    closure: Set[Node] = set()
    visited_target_nodes = set()
    q: Deque[Node] = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        metadata = cast(Dict[str, List], node.metadata)
        reverse_edges = metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is None:
                # this reverse graph is no longer alive
                # raise RuntimeError("Reverse graph is no longer alive")
                continue
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                visited_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, visited_target_nodes


# Wrapper that lets us take weak references to autograd Nodes
class Holder:
    def __init__(self, node: Node):
        self.node = node


def construct_reverse_graph(roots: List[Node]) -> List[Holder]:
    q: Deque[Node] = collections.deque()
    root_seen: Set[Node] = set()
    reverse_graph_refs: List[Holder] = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                metadata = cast(Dict[str, List], fn.metadata)
                reverse_edges = metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs
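
# A minimal sketch of how the reverse-graph helpers fit together (hypothetical
# tensors; the names below are illustrative only). construct_reverse_graph walks
# forward from the output's grad_fn and records "reverse_edges" metadata on each
# node; reverse_closure then walks those reverse edges from the inputs' nodes
# back toward the output:
#
#     x = torch.randn(2, requires_grad=True)
#     w = torch.randn(2, requires_grad=True)
#     out = (x * w).sum()
#
#     out_fn = _get_grad_fn_or_grad_acc(out)      # SumBackward0
#     x_fn = _get_grad_fn_or_grad_acc(x)          # AccumulateGrad for x
#
#     refs = construct_reverse_graph([out_fn])    # keep refs alive while querying
#     closure, hit_targets = reverse_closure([x_fn], set())
#     # closure now contains x's AccumulateGrad, MulBackward0, and SumBackward0
#     del refs                                    # reverse edges become dead weakrefs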


def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]:
    """
    Given a list of inputs and a list of parameters, return a list of parameter
    groups, where each group contains the parameters and the intermediates that
    are connected to the parameters.

    The returned list of parameter groups is a list of dictionaries, where each
    dictionary contains the following keys:
    - "params": a set of parameter (grad accumulator) nodes
    - "intermediates": a set of intermediate nodes that connect those parameters
      to the inputs' reverse closure
    """
    # reverse graph that starts with inputs, and goes up to the dOutput or the loss,
    # but omits weights and any subgraphs connecting weights to this closure
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups: Dict[Node, Dict[str, Set]] = dict()  # keyed on intermediates
    for i, param in enumerate(params):
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group: Dict[str, Set] = {
            "params": {param},
            "intermediates": intersected,
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Sanity check: union of all param_groups params should be equal to all params
    union_params: Set[Node] = set()
    seen_ids: Set[int] = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])

    # The assert will only be true if the input tensor requires gradients,
    # otherwise the autograd graph will miss the first layer of inputs
    # assert union_params == set(params)
    return unique_param_groups
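
# A minimal sketch of get_param_groups on a single hypothetical stage
# (illustrative names only):
#
#     x = torch.randn(4, requires_grad=True)      # stage input
#     w = torch.randn(4, requires_grad=True)      # stage "weight"
#     out = (x * w).sum()                         # stage output
#
#     out_fn = _get_grad_fn_or_grad_acc(out)
#     refs = construct_reverse_graph([out_fn])
#     groups = get_param_groups(
#         [_get_grad_fn_or_grad_acc(x)], [_get_grad_fn_or_grad_acc(w)]
#     )
#     # groups -> [{"params": {AccumulateGrad(w)},
#     #             "intermediates": {MulBackward0}}]
#     del refs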


def stage_backward_input(
    stage_outputs: List[torch.Tensor],
    output_grads: Optional[List[torch.Tensor]],
    input_values: List[torch.Tensor],
    weights: Iterator[Parameter],
):
    """
    Compute the gradients for only the stage inputs with respect to the stage outputs.
    """
    stage_output_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs))
    )
    stage_input_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, input_values))
    )
    weight_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, weights))
    )

    reverse_graph_refs = construct_reverse_graph(stage_output_grad_fns)
    param_groups = get_param_groups(stage_input_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(
                            param_group["intermediates"]
                        )
                    param_group["grads"][i] = grad_inputs

                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # Stage 0 inputs do not require grads? Should we skip in that case?
    if all(tensor.requires_grad for tensor in input_values):
        if output_grads is None:
            # In case this is the loss and there are no output_grads, then we just use 1s
            output_grads = [
                torch.ones_like(stage_output) for stage_output in stage_outputs
            ]

        dinputs = torch.autograd.grad(
            stage_outputs,
            inputs=input_values,
            grad_outputs=output_grads,
            retain_graph=True,
        )

        # update the gradients for inputs
        for i, inp in enumerate(input_values):
            if inp.grad is None:
                inp.grad = dinputs[i]
            else:
                inp.grad += dinputs[i]
    else:
        dinputs = None
    return dinputs, param_groups
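
# A minimal sketch of the split backward (input grads now, weight grads later),
# using hypothetical tensors in place of a real pipeline stage:
#
#     x = torch.randn(4, requires_grad=True)               # stage input
#     w = torch.nn.Parameter(torch.randn(4))               # stage weight
#     loss = (x * w).sum()                                 # stage output (here: a loss)
#
#     dinputs, param_groups = stage_backward_input(
#         [loss], None, [x], iter([w])
#     )
#     # x.grad is populated; w.grad is still None until stage_backward_weight
#     # runs over param_groups (see the sketch after that function).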


def stage_backward_weight(
    weights: Iterator[Parameter], param_groups: List[Dict[str, Any]]
):
    # map each weight's grad accumulator node back to the weight and its index
    grad_acc_to_weight = {}
    weight_grads = []
    for index, weight in enumerate(weights):
        grad_acc = _get_grad_fn_or_grad_acc(weight)
        grad_acc_to_weight[grad_acc] = weight, index
        weight_grads.append(weight.grad)

    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(
            GradientEdge(i, 0) for i in param_group["intermediates"]
        )
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        # print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        for grad_acc, dw in zip(param_group["params"], dweights):
            weight, index = grad_acc_to_weight[grad_acc]
            if weight.grad is None:
                weight.grad = dw
            else:
                weight.grad += dw
    # return grads in the original order weights were provided in
    return weight_grads
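
# Continuing the sketch shown after stage_backward_input (hypothetical tensors):
# once param_groups holds the saved intermediate grads, the weight grads can be
# computed separately, e.g. overlapped with other pipeline work:
#
#     stage_backward_weight(iter([w]), param_groups)
#     # w.grad is now populated, accumulated via the saved GradientEdges.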


def stage_backward(
    stage_output,
    output_grads,
    input_values,
    outputs_with_grads_idxs: Optional[List[int]] = None,  # deprecated, not used
):
    """
    This is a helper function to:
    1. compute the gradients for the stage inputs, and
    2. accumulate gradients for the stage module's parameters.

    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values (leaves
    in the autograd trace) as well as return a list of the gradients for the
    input values.
    """
    if outputs_with_grads_idxs is not None:
        # Deprecated, not used in runtime calls, only exists in compiler
        stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
        output_grads = [output_grads[i] for i in outputs_with_grads_idxs]

    try:
        # stage_output may be a composite datatype like dict. Extract all individual
        # tensor values here
        stage_output_tensors = []
        output_grad_tensors = []

        def extract_tensors_with_grads(output_val, grad_val):
            if isinstance(output_val, torch.Tensor):
                if not output_val.requires_grad and output_val.grad_fn is None:
                    return
                assert isinstance(
                    grad_val, (torch.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
            elif isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    return
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                for ov, gv in zip(output_val, grad_val):
                    extract_tensors_with_grads(ov, gv)
            elif isinstance(output_val, dict):
                if grad_val is None:
                    return
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                for k in output_val.keys():
                    extract_tensors_with_grads(output_val[k], grad_val[k])
            else:
                # Output is a non-tensor type; just ignore it
                pass

        extract_tensors_with_grads(stage_output, output_grads)

        torch.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
        )

        # Extract gradients wrt the input values
        grad_inputs = []
        for val in input_values:
            if isinstance(val, torch.Tensor):
                grad_inputs.append(val.grad)
            else:
                grad_inputs.append(None)

        # Alternative impl: `torch.autograd.grad`.
        # Note that `torch.autograd.grad` will not accumulate gradients into the
        # model's parameters.
        """
        inputs_with_grad = []
        for val in input_values:
            if isinstance(val, torch.Tensor) and val.requires_grad:
                inputs_with_grad.append(val)

        grad_inputs = torch.autograd.grad(
            stage_output_tensors, inputs_with_grad, output_grad_tensors,  # type: ignore[arg-type]
        )
        """

    except Exception as e:
        exc_msg = f"""
        Failed to run stage backward:
        Stage output: {map_debug_info(stage_output)}
        Output gradient: {map_debug_info(output_grads)}
        Input: {map_debug_info(input_values)}
        """
        raise RuntimeError(exc_msg) from e

    return grad_inputs
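
# A minimal sketch of the full (non-split) stage backward, with hypothetical
# tensors standing in for a real stage's outputs and inputs:
#
#     x = torch.randn(4, requires_grad=True)               # stage input
#     w = torch.nn.Parameter(torch.randn(4))               # stage parameter
#     out = x * w                                          # stage output
#     grad_out = torch.ones_like(out)                      # grad received from the next stage
#
#     grad_inputs = stage_backward((out,), (grad_out,), (x,))
#     # grad_inputs == [x.grad]; w.grad has been accumulated as a side effect.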


# TODO: handle requires_grad=False dynamically. Can we analyze this during initial
# IR emission?
def _null_coalesce_accumulate(lhs, rhs):
    """
    Coalesce two values, even if one of them is null, returning the non-null
    value.
    """
    if lhs is None:
        return rhs
    elif rhs is None:
        return lhs
    else:
        return torch.add(lhs, rhs)
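

# A minimal sketch of _null_coalesce_accumulate (hypothetical values):
#
#     _null_coalesce_accumulate(None, torch.ones(2))            # -> tensor([1., 1.])
#     _null_coalesce_accumulate(torch.ones(2), torch.ones(2))   # -> tensor([2., 2.])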