# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import collections
import logging
import weakref
from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union

import torch
from torch.autograd.graph import GradientEdge, Node
from torch.nn import Parameter

from ._debug import map_debug_info


logger = logging.getLogger(__name__)


def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
    """
    Get the grad function or grad accumulator for a tensor.

    Accumulate grad nodes are lazily created, so we need to create a
    dummy view in order to trigger their creation.
    """
    if t.requires_grad and t.grad_fn is None:
        # if there is no grad function (leaf tensor), use a view to trigger
        # creation of the accumulate grad node
        viewed_t = t.view_as(t)
        grad_fn = viewed_t.grad_fn
        if grad_fn is not None:
            return grad_fn.next_functions[0][0]
        else:
            raise RuntimeError(
                "Attempted to get grad_fn, but got None. "
                "Is this being created in a no-grad context?"
            )
    else:
        return t.grad_fn


def reverse_closure(
    roots: List[Node], target_nodes: Set[Node]
) -> Tuple[Set[Node], Set[Node]]:
    """
    Return the reverse closure of the given roots, i.e. the set of nodes that
    can be reached from the roots by following the reverse edges of the graph.
    Traversal stops at any node in ``target_nodes``; the target nodes that were
    reached are returned separately as the second element of the tuple.
    """
    # Recurse until we reach a target node
    closure: Set[Node] = set()
    visited_target_nodes = set()
    q: Deque[Node] = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        metadata = cast(Dict[str, List], node.metadata)
        reverse_edges = metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is None:
                # this reverse graph is no longer alive
                # raise RuntimeError("Reverse graph is no longer alive")
                continue
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                visited_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, visited_target_nodes


# Wrapper so the reverse graph can be tracked via weak references
class Holder:
    def __init__(self, node: Node):
        self.node = node


def construct_reverse_graph(roots: List[Node]) -> List[Holder]:
    q: Deque[Node] = collections.deque()
    root_seen: Set[Node] = set()
    reverse_graph_refs: List[Holder] = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                metadata = cast(Dict[str, List], fn.metadata)
                reverse_edges = metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs
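

# Illustrative sketch (not used by the pipeline runtime): a minimal example of
# how `construct_reverse_graph` annotates reverse edges starting from an
# output's grad_fn, and how `reverse_closure` then walks those edges from a
# leaf's accumulate-grad node back toward the output. The helper name and the
# toy tensors below are made up for illustration only.
def _example_reverse_closure() -> Tuple[Set[Node], Set[Node]]:
    a = torch.randn(2, requires_grad=True)
    b = torch.randn(2, requires_grad=True)
    out = (a * b).sum()

    out_fn = _get_grad_fn_or_grad_acc(out)  # SumBackward0
    acc_a = _get_grad_fn_or_grad_acc(a)  # AccumulateGrad for `a`
    assert out_fn is not None and acc_a is not None

    # Annotate "reverse_edges" metadata; keep `refs` alive while traversing.
    refs = construct_reverse_graph([out_fn])
    # Walk from `a`'s accumulate-grad node toward the output, stopping at it.
    closure, reached_targets = reverse_closure([acc_a], {out_fn})
    del refs  # dropping the holders invalidates the reverse edges
    return closure, reached_targets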


def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]:
    """
    Given a list of inputs and a list of parameters, return a list of parameter
    groups, where each group contains the parameters and the intermediates that
    are connected to the parameters.

    The returned list of parameter groups is a list of dictionaries, where each
    dictionary contains the following keys:
    - "params": a set of parameters
    - "intermediates": a set of intermediates
    """
    # reverse graph that starts with inputs, and goes up to the dOutput or the loss,
    # but omits weights and any subgraphs connecting weights to this closure
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups: Dict[Node, Dict[str, Set]] = dict()  # keyed on intermediates
    for param in params:
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group: Dict[str, Set] = {
            "params": {param},
            "intermediates": intersected,
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Sanity check: union of all param_groups params should be equal to all params
    union_params: Set[Node] = set()
    seen_ids: Set[int] = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])

    # The assert will only be true if the input tensor requires gradients,
    # otherwise the autograd graph will miss the first layer of inputs
    # assert union_params == set(params)
    return unique_param_groups
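

# Illustrative sketch (not used by the pipeline runtime): grouping a weight
# with the intermediate node where its subgraph meets the input-to-output
# path. The helper name and the toy single-weight "stage" are made up for
# illustration only.
def _example_get_param_groups() -> List[Dict[str, Any]]:
    x = torch.randn(4, requires_grad=True)  # stage input
    w = Parameter(torch.randn(4))  # stage weight
    out = (x * w).sum()  # stage output

    out_fn = _get_grad_fn_or_grad_acc(out)
    in_fn = _get_grad_fn_or_grad_acc(x)
    w_fn = _get_grad_fn_or_grad_acc(w)
    assert out_fn is not None and in_fn is not None and w_fn is not None

    # Reverse edges must be annotated from the stage output before the
    # closures inside get_param_groups can be walked.
    refs = construct_reverse_graph([out_fn])
    param_groups = get_param_groups([in_fn], [w_fn])
    del refs
    # Expect one group: {"params": {w_fn}, "intermediates": {MulBackward0 node}}
    return param_groups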


def stage_backward_input(
    stage_outputs: List[torch.Tensor],
    output_grads: Optional[List[torch.Tensor]],
    input_values: List[torch.Tensor],
    weights: Iterator[Parameter],
):
    """
    Compute the gradients for only the stage inputs with respect to the stage outputs.
    """
    stage_output_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs))
    )
    stage_input_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, input_values))
    )
    weight_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, weights))
    )

    reverse_graph_refs = construct_reverse_graph(stage_output_grad_fns)
    param_groups = get_param_groups(stage_input_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(
                            param_group["intermediates"]
                        )
                    param_group["grads"][i] = grad_inputs

                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # Stage 0 inputs do not require grads? Should we skip in that case?
    if all(tensor.requires_grad for tensor in input_values):
        if output_grads is None:
            # If this is the loss and there are no output_grads, then we just use 1s
            output_grads = [
                torch.ones_like(stage_output) for stage_output in stage_outputs
            ]

        dinputs = torch.autograd.grad(
            stage_outputs,
            inputs=input_values,
            grad_outputs=output_grads,
            retain_graph=True,
        )

        # update the gradients for inputs
        for i, inp in enumerate(input_values):
            if inp.grad is None:
                inp.grad = dinputs[i]
            else:
                inp.grad += dinputs[i]
    else:
        dinputs = None
    return dinputs, param_groups


def stage_backward_weight(
    weights: Iterator[Parameter], param_groups: List[Dict[str, Any]]
):
    # map weights to param_group_weights
    grad_acc_to_weight = {}
    weight_grads = []
    for index, weight in enumerate(weights):
        grad_acc = _get_grad_fn_or_grad_acc(weight)
        grad_acc_to_weight[grad_acc] = weight, index
        weight_grads.append(weight.grad)

    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(
            GradientEdge(i, 0) for i in param_group["intermediates"]
        )
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        # print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        for grad_acc, dw in zip(param_group["params"], dweights):
            weight, index = grad_acc_to_weight[grad_acc]
            if weight.grad is None:
                weight.grad = dw
            else:
                weight.grad += dw
    # return grads in the original order weights were provided in
    return weight_grads
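

# Illustrative sketch (not used by the pipeline runtime): how the two halves
# of the split backward compose. `stage_backward_input` produces the input
# gradients (and the per-group intermediate grads captured by the prehooks),
# and `stage_backward_weight` later turns those into weight gradients. The
# helper name and the single-Linear "stage" are made up for illustration.
def _example_split_input_weight_backward():
    x = torch.randn(2, 4, requires_grad=True)
    lin = torch.nn.Linear(4, 3)
    out = lin(x)
    grad_out = torch.ones_like(out)

    # Phase 1: input gradients only; weight gradients are deferred.
    dinputs, param_groups = stage_backward_input(
        [out], [grad_out], [x], lin.parameters()
    )
    # x.grad is populated here; lin.weight.grad / lin.bias.grad are not yet.
    # Phase 2: weight gradients, computed from the captured intermediate grads.
    stage_backward_weight(lin.parameters(), param_groups)
    return dinputs, lin.weight.grad, lin.bias.grad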


def stage_backward(
    stage_output,
    output_grads,
    input_values,
    outputs_with_grads_idxs: Optional[List[int]] = None,  # deprecated, not used
):
    """
    This is a helper function to:
    1. compute the gradients for the stage inputs, and
    2. accumulate gradients for the stage module's parameters.

    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values (leaves
    in the autograd trace) as well as return a list of the gradients for the
    input values.
    """
    if outputs_with_grads_idxs is not None:
        # Deprecated, not used in runtime calls, only exists in compiler
        stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
        output_grads = [output_grads[i] for i in outputs_with_grads_idxs]

    try:
        # stage_output may be a composite datatype like dict. Extract all individual
        # tensor values here
        stage_output_tensors = []
        output_grad_tensors = []

        def extract_tensors_with_grads(output_val, grad_val):
            if isinstance(output_val, torch.Tensor):
                if not output_val.requires_grad and output_val.grad_fn is None:
                    return
                assert isinstance(
                    grad_val, (torch.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
            elif isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    return
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                for ov, gv in zip(output_val, grad_val):
                    extract_tensors_with_grads(ov, gv)
            elif isinstance(output_val, dict):
                if grad_val is None:
                    return
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                for k in output_val.keys():
                    extract_tensors_with_grads(output_val[k], grad_val[k])
            else:
                # Output is a non-tensor type; just ignore it
                pass

        extract_tensors_with_grads(stage_output, output_grads)

        torch.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
        )

        # Extract gradients wrt the input values
        grad_inputs = []
        for val in input_values:
            if isinstance(val, torch.Tensor):
                grad_inputs.append(val.grad)
            else:
                grad_inputs.append(None)

        # Alternative impl: `torch.autograd.grad`.
        # Note that `torch.autograd.grad` will not accumulate gradients into the
        # model's parameters.
        """
        inputs_with_grad = []
        for val in input_values:
            if isinstance(val, torch.Tensor) and val.requires_grad:
                inputs_with_grad.append(val)

        grad_inputs = torch.autograd.grad(
            stage_output_tensors, inputs_with_grad, output_grad_tensors,  # type: ignore[arg-type]
        )
        """

    except Exception as e:
        exc_msg = f"""
        Failed to run stage backward:
        Stage output: {map_debug_info(stage_output)}
        Output gradient: {map_debug_info(output_grads)}
        Input: {map_debug_info(input_values)}
        """
        raise RuntimeError(exc_msg) from e

    return grad_inputs


# TODO: handling requires_grad=False dynamically. Can we analyze this during initial
# IR emission?
def _null_coalesce_accumulate(lhs, rhs):
    """
    Coalesce two values, even if one of them is null, returning the non-null
    value.
    """
    if lhs is None:
        return rhs
    elif rhs is None:
        return lhs
    else:
        return torch.add(lhs, rhs)
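

# Illustrative sketch (not used by the pipeline runtime): the fused helper
# `stage_backward` in one shot, plus `_null_coalesce_accumulate` combining a
# possibly-missing gradient with a new one. The helper name and toy tensors
# are made up for illustration only.
def _example_stage_backward():
    x = torch.randn(3, requires_grad=True)
    w = Parameter(torch.randn(3))
    out = (x * w).sum()

    # Computes and returns d(out)/d(x); also accumulates w.grad in place.
    (grad_x,) = stage_backward(
        stage_output=[out],
        output_grads=[torch.ones_like(out)],
        input_values=[x],
    )
    # Accumulate another gradient, tolerating a missing (None) first operand.
    total_grad_x = _null_coalesce_accumulate(None, grad_x)
    return total_grad_x, w.grad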