# mypy: allow-untyped-defs
"""
Contains various utils for AOTAutograd, including those for handling collections.
"""

import dataclasses
import operator
import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import getArtifactLogger
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import py_sym_types


KNOWN_TYPES = [
    torch.Tensor,
    BackwardState,
    int,
    str,
    float,
    bool,
    type(None),
    *py_sym_types,
    FakeScriptObject,
    torch.ScriptObject,
]

original_zip = zip

aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")


def strict_zip(*iterables, strict=True, **kwargs):
    if not strict:
        return original_zip(*iterables, **kwargs)

    length = len(iterables[0])
    for iterable in iterables[1:]:
        if len(iterable) != length:
            raise ValueError(
                "The iterables have different lengths and strict mode is enabled."
            )

    return original_zip(*iterables, **kwargs)


def _get_symint_hints(exprs):
    """
    Get the hints of a list/tuple of int/SymInt.
    """
    if isinstance(exprs, (list, tuple)):
        return type(exprs)(_get_symint_hints(e) for e in exprs)
    elif isinstance(exprs, torch.SymInt):
        return exprs.node.shape_env.size_hint(exprs.node.expr)
    else:
        return exprs


def partial_flatten_asdict(obj: Any) -> Any:
    if dataclasses.is_dataclass(obj):
        return {
            field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)
        }
    elif isinstance(obj, (list, tuple)):
        return obj.__class__([partial_flatten_asdict(item) for item in obj])
    elif isinstance(obj, dict):
        return {k: partial_flatten_asdict(v) for k, v in obj.items()}
    else:
        return obj


def normalize_as_list(x):
    if isinstance(x, tuple):
        return list(x)
    elif isinstance(x, list):
        return x
    return [x]


def _get_autocast_states():
    return [
        torch.is_autocast_enabled("cuda"),
        torch.is_autocast_enabled("cpu"),
        torch.get_autocast_dtype("cuda"),
        torch.get_autocast_dtype("cpu"),
        torch.is_autocast_cache_enabled(),
    ]


def make_boxed_func(f):
    def g(args):
        return f(*args)

    g._boxed_call = True  # type: ignore[attr-defined]
    return g


def make_boxed_compiler(compiler):
    @wraps(compiler)
    def f(fx_g, inps):
        out_f = compiler(fx_g, inps)
        fx_g = make_boxed_func(out_f)
        return fx_g

    return f


def call_func_at_runtime_with_args(
    f, args: Union[Tuple[Any], List[Any]], steal_args=False, disable_amp=False
):
    if not steal_args:
        args = list(args)
    assert isinstance(args, list)

    context = torch._C._DisableAutocast if disable_amp else nullcontext
    with context():
        if hasattr(f, "_boxed_call"):
            out = normalize_as_list(f(args))
        else:
            # TODO: Please remove soon
            # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
            warnings.warn(
                "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
                "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
                "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
            )
            out = normalize_as_list(f(*args))
    return out
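

# Illustrative sketch (not part of this module's API) of the "boxed" calling
# convention referenced above: a boxed callable takes a single list of args,
# which lets the runtime free inputs early by clearing the list. The names
# t1/t2 below are hypothetical tensors.
#
#   def compiled(args):              # boxed: one list instead of positional args
#       x, y = args
#       args.clear()                 # inputs may be dropped as soon as they are read
#       return [x + y]
#
#   compiled._boxed_call = True      # what make_boxed_func sets for you
#   out = call_func_at_runtime_with_args(compiled, [t1, t2], steal_args=True)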


# Inspired by autodidax (thanks!)
class PytreeThunk:
    spec: Optional[pytree.TreeSpec] = None
    # These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
    is_simple: Optional[
        bool
    ] = None  # if the output spec is a tuple/list, we won't bother unflattening it.
    is_really_simple: Optional[bool] = None  # if the output spec is a LeafSpec

    def set(self, spec: pytree.TreeSpec) -> None:
        assert self.spec is None or self.spec == spec
        assert spec is not None
        self.spec: pytree.TreeSpec = spec
        if self.spec.type in {tuple, list} and all(
            child.is_leaf() for child in spec.children_specs
        ):
            self.is_simple = True
        if self.spec.is_leaf():
            self.is_really_simple = True

    def unflatten(self, x: List[Any]) -> Any:
        if self.is_really_simple:
            return x[0]
        if self.is_simple:
            return x
        assert self.spec is not None
        return pytree.tree_unflatten(x, self.spec)


# Creates a function that returns flattened inputs and outputs.
# Also returns the output tree spec, which is needed to recover the "unflattened"
# output tree structure later.
def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]:
    if kwargs is None:
        kwargs = {}
    # Save the args_spec for flat_tensor_args to unflatten while tracing
    _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
    out_spec = PytreeThunk()

    def flat_fn(*flat_args):
        # The inputs are flattened tensor args. Prepare the args in the
        # order that the original function expects. Add static args as well.
        # They will appear as tensor constants in the traced graph.
        nonlocal out_spec
        args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
        tree_out = fn(*args, **kwargs)
        flat_out, spec = pytree.tree_flatten(tree_out)
        for i in flat_out:
            is_known_type = False
            for j in KNOWN_TYPES:
                if isinstance(i, j):
                    is_known_type = True
                    break
            if not is_known_type:
                raise RuntimeError(
                    f"Found {type(i)} in output, which is not a known type. "
                    "If this type holds tensors, you need to register a pytree for it. "
                    "See https://github.com/pytorch/functorch/issues/475 for a brief "
                    "explanation why. If you don't need to register a pytree, please "
                    "leave a comment explaining your use case and we'll make this more "
                    "ergonomic to deal with"
                )
        out_spec.set(spec)
        return flat_out

    # Can't use functools.wraps here because the wrapper has a different
    # calling convention
    if hasattr(fn, "_orig_mod"):
        flat_fn._orig_mod = fn._orig_mod  # type: ignore[attr-defined]

    return flat_fn, out_spec
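

# Illustrative sketch (hypothetical fn/inputs, not part of this module) of the
# flatten/unflatten roundtrip that create_tree_flattened_fn and PytreeThunk implement:
#
#   def fn(x, d):
#       return {"out": x + d["y"]}
#
#   args, kwargs = (x,), {"d": {"y": y}}
#   flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs)
#   flat_args, _ = pytree.tree_flatten((args, kwargs))
#   flat_out = flat_fn(*flat_args)          # plain list of leaves (tensors, ints, ...)
#   result = out_spec.unflatten(flat_out)   # original {"out": ...} structure restored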


# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
#     to its index pre-synthetic-base calling convention.
# (2) There could be multiple such indices, if this index corresponds to a synthetic base
#     that has multiple input aliases.
# (3) If any of those corresponding inputs get metadata mutations, then we clone the base.
def maybe_to_fresh_input(idx, t, meta):
    if not isinstance(t, torch.Tensor):
        return t
    if idx in meta.mutated_inp_runtime_indices:
        # We only need to bother cloning mutated inputs that participate in autograd.
        mutated_inp_idx = meta.mutated_inp_runtime_indices.index(idx)
        if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data:
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the mutation
            return t.clone()
        if meta.input_info[idx] and meta.input_info[idx].mutates_metadata:
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the metadata mutation
            return t.view(t.shape)
    return t
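

# Illustrative sketch (standalone, hypothetical tensors) of why a data mutation
# needs a clone while a metadata-only mutation only needs a fresh view:
#
#   m = torch.randn(2, 3)
#   fresh = m.clone()        # data mutation: autograd.grad() must see pre-mutation values
#   m.add_(1)                # the traced graph may now mutate m; `fresh` is unaffected
#
#   fresh = m.view(m.shape)  # metadata-only mutation: storage stays shared,
#   m.t_()                   # only m's sizes/strides change; `fresh` keeps the old ones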


def is_with_effects(node):
    return (
        node.op == "call_function"
        and node.target == torch.ops.higher_order.with_effects
    )


def is_with_effects_op(node, op):
    return is_with_effects(node) and node.args[1] == op


def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
    # Remove the tokens from the inputs/outputs of the graph since inductor does
    # not want these extra inputs/outputs, and replace them with
    # _make_token() to create a token, and _sink_tokens() to collect the
    # tokens. See Note [Side-Effectful Tokens in AOTAutograd]
    # Logic:
    # 1. Inputs are identified as input tokens:
    #    - if they are used as the first argument of a with_effects call
    #
    # 2. Outputs are identified as output tokens:
    #    - if they are produced by getitem(with_effects, 0)
    #
    # 3. Check the invariants on the number of input/output tokens:
    #    forward:
    #      expected_num_erased_inputs == len(fw_metadata.tokens)
    #      expected_num_erased_outputs == len(fw_metadata.tokens)
    #    backward:
    #      expected_num_erased_inputs == fw_metadata.num_backward_tokens
    #      expected_num_erased_outputs == fw_metadata.num_backward_tokens
    num_forward_tokens = len(fw_metadata.tokens)
    num_backward_tokens = fw_metadata.num_backward_tokens

    def rewrite_with_effects_input_token(module, node):
        with module.graph.inserting_before(node):
            new_token_node = module.graph.call_function(
                torch.ops.prims._make_token.default, ()
            )
            new_token_node.meta["val"] = torch.tensor([])
            new_token_node.meta["tensor_meta"] = torch.tensor([])

            args = list(node.args)
            args[0] = new_token_node
            node.args = tuple(args)

    def rewrite_output(module, node, output_token_nodes, other_output_args):
        for output_token_node in output_token_nodes:
            assert (
                output_token_node.op == "call_function"
                and output_token_node.target == operator.getitem
                and output_token_node.args[1] == 0
            )
        with module.graph.inserting_before(node):
            module.graph.call_function(
                torch.ops.prims._sink_tokens.default,
                (output_token_nodes,),
            )
            node.args = (other_output_args,)

    def do(module, subgraph, expected_num_erased):
        num_erased_inputs = 0
        num_erased_outs = 0
        input_nodes = []
        input_token_nodes = set()
        with_effect_nodes = []
        output_token_nodes = []
        other_output_nodes = []
        for i, node in enumerate(module.graph.nodes):
            if node.op == "placeholder":
                input_nodes.append(node)
            elif is_with_effects(node):
                with_effect_nodes.append(node)
                if node.args[0] in input_nodes:
                    input_token_nodes.add(node.args[0])
                    rewrite_with_effects_input_token(module, node)
            elif node.op == "output":
                outs = node.args[0]
                for out in outs:
                    if (
                        isinstance(out, torch.fx.node.Node)
                        and out.op == "call_function"
                        and out.target == operator.getitem
                        and out.args[1] == 0
                        and out.args[0] in with_effect_nodes
                    ):
                        output_token_nodes.append(out)
                    else:
                        other_output_nodes.append(out)

                rewrite_output(module, node, output_token_nodes, other_output_nodes)
                num_erased_outs = len(output_token_nodes)

        for input_token_node in input_token_nodes:
            module.graph.erase_node(input_token_node)

        num_erased_inputs = len(input_token_nodes)

        assert (
            num_erased_inputs == expected_num_erased
        ), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
        assert (
            num_erased_outs == expected_num_erased
        ), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"

        module.recompile()

    if num_forward_tokens > 0:
        if aot_config.enable_log:
            from torch._dynamo.utils import lazy_format_graph_code

            aot_graphs_effects_log.debug(
                "%s",
                lazy_format_graph_code(
                    "Forward graph before unlifting tokens",
                    fw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
        do(
            fw_module,
            "forward",
            num_forward_tokens,
        )

    if bw_module is not None and num_backward_tokens > 0:
        if aot_config.enable_log:
            from torch._dynamo.utils import lazy_format_graph_code

            aot_graphs_effects_log.debug(
                "%s",
                lazy_format_graph_code(
                    "Backward graph before unlifting tokens",
                    bw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
        do(bw_module, "backward", num_backward_tokens)

    # This is sad, but we need to update the metadata to get rid of
    # the tokens.
    fw_metadata.tokens = {}
    fw_metadata.num_backward_tokens = 0
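

# Illustrative sketch of the rewrite unlift_tokens performs (schematic FX graph,
# not real printed output; `some_op` and `arg` are hypothetical):
#
#   before:  token, arg = placeholders
#            we = with_effects(token, some_op, arg)
#            new_token = getitem(we, 0); res = getitem(we, 1)
#            return (new_token, res)
#
#   after:   arg = placeholder                        # token input erased
#            t = prims._make_token()
#            we = with_effects(t, some_op, arg)
#            res = getitem(we, 1)
#            prims._sink_tokens([getitem(we, 0)])     # token output sunk, not returned
#            return (res,)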
tokens", 375 bw_module, 376 aot_config.aot_id, 377 include_stride=True, 378 include_device=True, 379 colored=True, 380 ), 381 ) 382 do(bw_module, "backward", num_backward_tokens) 383 384 # This is sad, but we need to update the metadata to get rid of 385 # the tokens. 386 fw_metadata.tokens = {} 387 fw_metadata.num_backward_tokens = 0 388 389 390def root_module_when_exporting_non_strict(flat_fn): 391 # When exporting in non-strict mode, we wrap the root module in a specific pattern. 392 # See `_aot_export_non_strict` in torch.export._trace.py. 393 # We look for that wrapping pattern here. 394 if hasattr(flat_fn, "_orig_mod") and hasattr(flat_fn._orig_mod, "_export_root"): 395 return flat_fn._orig_mod._export_root 396 else: 397 return None 398 399 400def copy_fwd_metadata_to_bw_nodes(fx_g): 401 """ 402 Input: `fx_g` which contains the joint fwd+bwd FX graph created by 403 aot_autograd. 404 405 This function walks the graph and copies over metadata from forward nodes 406 to backward nodes, using the `seq_nr` field as a one-to-many mapping 407 from forward node to backward node. This metadata is useful for performance 408 profiling and debugging. 409 """ 410 411 def _is_forward_node_with_seq_nr(node): 412 # For now, assume that if nn_module_stack_metadata is populated, this 413 # node is from the forward. Ignore nodes without `seq_nr`. 414 # TODO(future): there is likely a less brittle way to do this by walking 415 # the descendants of graph inputs corresponding to fwd inputs, didn't 416 # seem obvious at first glance on how to partition graph inputs into 417 # fwd vs bwd without relying on string names. 418 return "nn_module_stack" in node.meta and "seq_nr" in node.meta 419 420 def _is_backward_node_with_seq_nr(node): 421 # For now, assume that if nn_module_stack_metadata is not populated, 422 # this node is from the backward. Ignore nodes without `seq_nr`. 423 # TODO(future): there is likely a less brittle way to do this, same 424 # as with the forward. 425 return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta 426 427 fwd_seq_nr_to_node = {} 428 for node in fx_g.graph.nodes: 429 if not _is_forward_node_with_seq_nr(node): 430 continue 431 seq_nr = node.meta["seq_nr"] 432 if seq_nr in fwd_seq_nr_to_node: 433 # If we already saw an op with the current `seq_nr`, that means 434 # that the current op did not create an autograd node, and there 435 # is no corresponding backward node, so we skip. 436 continue 437 fwd_seq_nr_to_node[node.meta["seq_nr"]] = node 438 439 for node in fx_g.graph.nodes: 440 if not _is_backward_node_with_seq_nr(node): 441 continue 442 # fwd_node should always exist, but handle non-existence just in case 443 fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"]) 444 if fwd_node is not None: 445 node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"] 446 node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack") 447