# mypy: allow-untyped-defs
"""
This module is responsible for transforming functions to be traced into a form
that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis)
to handle.

It does so by:
1. functionalization (including RNG functionalization)
2. creating a joint graph when required
3. transforming mutations into extra outputs
4. dispatching subclasses
"""

import warnings
from contextlib import contextmanager, nullcontext
from functools import wraps
from typing import Any, Callable, List, Tuple, Union
from unittest.mock import patch

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch import Tensor
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
    maybe_disable_thunkify,
    maybe_enable_thunkify,
)
from torch.fx.experimental.symbolic_shapes import (
    definitely_false,
    PropagateUnbackedSymInts,
    sym_eq,
)
from torch.nn.utils import stateless

from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .functional_utils import (
    from_fun,
    has_data_mutation,
    has_metadata_mutation,
    is_fun,
    sync_functional_tensor,
    to_fun,
    was_inductor_storage_resized,
)
from .logging_utils import setup_stacktrace_preservation_hooks
from .schemas import (
    AOTConfig,
    MutationType,
    OutputType,
    SubclassMeta,
    SubclassTracingInfo,
    ViewAndMutationMeta,
)
from .subclass_utils import (
    create_subclass_meta,
    remap_unwrapped_subclass_arg_indices,
    requires_subclass_dispatch,
    unwrap_tensor_subclasses,
    wrap_tensor_subclasses_maybe_joint,
)
from .utils import maybe_to_fresh_input


# This function returns a new function that returns mutated inputs as outputs.
# If keep_data_input_mutations is set, then we assume that data-only mutations
# will be left in the graph, and we only return metadata-mutated inputs as outputs.
def fn_input_mutations_to_outputs(
    fn: Callable,
    meta: ViewAndMutationMeta,
    keep_data_input_mutations: bool,
) -> Any:
    @wraps(fn)
    def inner_fn(*args):
        outs = fn(*args)
        assert len(meta.output_info) == len(outs)
        # The compiled fw will return mutated input tensors, *including* metadata-only mutations.
        # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs.
        # (because data-only input mutations are handled directly in the compiled graph)
        mutated_inputs_to_return = [
            x for (i, x) in enumerate(args) if i in meta.mutated_inp_runtime_indices
        ]
        return *mutated_inputs_to_return, *outs

    return inner_fn
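

# As a rough sketch of the transform above (hypothetical user function, for
# illustration only): if input 0 is mutated and is tracked in
# meta.mutated_inp_runtime_indices, then
#
#   def f(x, y):
#       x.add_(1)
#       return x * y
#
# is traced (after functionalization) roughly as
#
#   def traced_f(x, y):
#       x_updated = x + 1
#       out = x_updated * y
#       return x_updated, out
#
# i.e. the updated value of the mutated input becomes an extra graph output,
# and a runtime wrapper is responsible for copying it back into the real `x`.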


# This function takes in a fn with external aliasing and mutation,
# and returns a new fn with no external aliasing and mutation,
# as needed for autograd.
# The main transformations are:
# - Return mutated inputs as extra outputs
# - Clone mutated inputs that require gradients,
#   because autograd will require us to pass the pre-mutated inputs into autograd.grad
# - Return intermediate bases of outputs as additional outputs,
#   needed to appease autograd.Function
# The new function returns:
# (1) The updated outputs
# (2) A boolean mask of len(new_fn_outputs),
#     that can be used to tell autograd.grad which outputs should get tangents
#     if we trace the backward.
def fn_prepped_for_autograd(
    fn: Callable,
    meta: ViewAndMutationMeta,
) -> Any:
    @wraps(fn)
    def inner_fn(*args):
        args_maybe_cloned = [
            maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args)
        ]

        outs = fn(*args_maybe_cloned)
        assert isinstance(outs, (tuple, list))
        outs = list(outs)
        assert len(meta.output_info) == len(outs)

        mutated_inputs_to_return = [
            x
            for (i, x) in enumerate(args_maybe_cloned)
            if i in meta.mutated_inp_runtime_indices
        ]

        intermediate_bases = []
        for i, (o, info) in enumerate(zip(outs, meta.output_info)):
            if info.output_type == OutputType.alias_of_intermediate_save_as_output:
                intermediate_bases.append(o._base)

        assert meta.num_intermediate_bases == len(intermediate_bases)

        # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases)
        fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases

        # Also return a boolean mask specifying which outputs of this function will be used as tangents
        mutated_inputs_grad_mask = [
            meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data
            and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad
            for (i, x) in enumerate(mutated_inputs_to_return)
        ]

        # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw
        # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
        # which we *should* send to grad()
        output_grad_mask = [
            meta.output_info[i].output_type
            in [
                OutputType.non_alias,
                OutputType.unsafe_view_alias,
                OutputType.custom_function_view,
            ]
            # Also, only tensor outputs should participate in the backward
            # (in particular, SymInt outputs in the forward graph shouldn't get tangents)
            and issubclass(meta.output_info[i].raw_type, Tensor)
            and meta.output_info[i].requires_grad
            for (i, x) in enumerate(outs)
        ]

        intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))]

        out_grad_mask = (
            mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask
        )
        assert len(out_grad_mask) == len(fw_outs_to_return)

        # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!)
        # and not primals (the preserved inputs, pre-mutation, that we pass to grad())
        # This is annoying: our joint function needs to be aware of functionalization
        # (syncing mutated inputs before calling autograd.grad())
        # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner.
        for arg in args_maybe_cloned:
            if not isinstance(arg, Tensor):
                continue
            sync_functional_tensor(arg)

        return fw_outs_to_return, out_grad_mask

    return inner_fn
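

# For intuition, a sketch of the convention produced by fn_prepped_for_autograd
# (hypothetical example, not from this file): suppose fn has one mutated input
# `x` that requires grad and mutates data, and returns one plain output plus
# one output that aliases an intermediate. Then inner_fn returns roughly
#
#   fw_outs_to_return = (x_updated, out_plain, out_view, base_of_out_view)
#   out_grad_mask     = [True,      True,      False,    True]
#
# The alias itself is masked out (autograd.grad should not be asked for a
# tangent for it); its base, returned as an extra output, gets a tangent
# instead.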


# Given a fn, computes the joint.
# NOTE: fn is expected to have the following behavior:
# (1) fn() needs to return a tuple of (outs, mask),
#     where `mask` tells us which outputs are meant to have tangents.
#     We don't know this info automatically, because we don't actually want to blindly
#     compute tangents for every output that requires grad.
#     Specifically, outputs that alias inputs won't participate in the backward and get tangents.
# (2) fn() cannot mutate any inputs that require gradient.
#     Otherwise, when we compute autograd.grad(), we will not take those input mutations into account
#     (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
    def inner_fn(primals: List[Any], tangents: List[Any]):
        outs, tangent_mask = fn(*primals)

        assert len(tangent_mask) == len(outs)
        outs_to_grad = [
            o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
        ]
        assert len(outs_to_grad) == len(tangents)

        # Get the inputs that need gradients
        grad_primals = []
        inputs_needs_grads = []
        # Note that we're not using primals here,
        # being careful not to pass any mutated inputs into autograd.grad()
        for p in primals:
            is_grad_tensor = isinstance(p, Tensor) and p.requires_grad
            inputs_needs_grads.append(is_grad_tensor)
            if is_grad_tensor:
                grad_primals.append(p)

        # Get the outputs that need gradients
        needed_outs = []
        needed_tangents = []
        for out, tangent in zip(outs_to_grad, tangents):
            if isinstance(out, Tensor) and out.requires_grad:
                # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
                # The issue is that we are sensitive to decomps that don't accurately maintain
                # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
                # The not definitely_false is also sketchy; if unbacked
                # symints are involved, we're just going to assume that the
                # decomps set up the base shape correctly
                needed_outs.append(
                    out
                    if not definitely_false(sym_eq(out.shape, tangent.shape))
                    else out.view(tangent.shape)
                )
                needed_tangents.append(tangent)

        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])

        if config.functionalize_rng_ops:
            PhiloxStateTracker.mark_beginning_of_backward()
        backward_out: Tuple[Tensor, ...] = ()
        # Call the backwards pass
        if grad_primals:
            functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )
            if functional_tensor_mode is not None:
                # Side-Effect Tokens:
                # We want to have independent chains of tokens for forward and backward.
                # functional_tensor_mode._tokens is used by both.
                # We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output,
                # to return them as joint graph outputs.
                # We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward.
                # Joint graph tracing allows token discovery,
                # so all the tokens in backward will be created and added as graph inputs during tracing.
                functional_tensor_mode._tokens_forward_output = (
                    functional_tensor_mode._tokens
                )
                functional_tensor_mode._tokens = {}

            with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta():
                # for full graph export, we always export a joint graph where we assume no tangents are needed.
                if aot_config.no_tangents:
                    assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
                    backward_out = torch.autograd.grad(
                        needed_outs,
                        grad_primals,
                        allow_unused=True,
                    )
                else:
                    backward_out = torch.autograd.grad(
                        needed_outs,
                        grad_primals,
                        grad_outputs=needed_tangents,
                        allow_unused=True,
                    )
        backward_out_iter = iter(backward_out)
        return outs, [
            next(backward_out_iter) if i else None for i in inputs_needs_grads
        ]

    def inner_fn_with_anomaly(*args):
        with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
            warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.")
            with torch.autograd.detect_anomaly(check_nan=False):
                return inner_fn(*args)

    return inner_fn_with_anomaly
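

# A minimal sketch of the joint calling convention produced by create_joint
# (hypothetical names, for illustration only; `prepped_fn` stands in for the
# output of fn_prepped_for_autograd):
#
#   joint = create_joint(prepped_fn, aot_config=aot_config)
#   outs, grads = joint([x], [tangent_for_out])
#
# `grads` has one entry per primal: the computed gradient if that primal is a
# tensor requiring grad, and None otherwise. This (primals, tangents) ->
# (outputs, input_gradients) structure is what the partitioner later splits
# into separate forward and backward graphs.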


def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any:
    # Functionalization of rng ops changes the calling convention of the joint graph:
    # it goes from (primals, tangents) to
    # (primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset).
    # At runtime, we pass on the current seed and offset. This is hidden from
    # the user.
    fake_mode = detect_fake_mode()
    if fake_mode is None:
        fake_mode = nullcontext()

    def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
        out = PhiloxStateTracker.get_state_as_tensor()
        return out

    def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"):
        PhiloxStateTracker.set_state_from_tensor(x)

    def append_rng_offsets(args):
        if trace_joint:
            # args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs)
            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_outputs, new_bwd_rng_offset)
            return (
                (*args[0], PhiloxStateTracker.get_updated_fwd_offset()),
                (*args[1], PhiloxStateTracker.get_updated_bwd_offset()),
            )
        else:
            # args signature before: Tuple(fwd_outputs)
            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset)
            return (*args, PhiloxStateTracker.get_updated_fwd_offset())

    def traced_joint(
        primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset
    ):
        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
            "torch.cuda.set_rng_state", override_set_rng_state
        ):
            return append_rng_offsets(func(primals, tangents))

    def traced_forward(*primals_fwd_seed_fwd_base_offset):
        # The signature is (*primals, seed, offset)
        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
            "torch.cuda.set_rng_state", override_set_rng_state
        ):
            return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))

    if trace_joint:
        # Get the current seed and offset to set up tracing.
        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
        PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
        return traced_joint, (
            *args,
            fwd_seed,
            fwd_base_offset,
            bwd_seed,
            bwd_base_offset,
        )
    else:
        # Get the current seed and offset to set up tracing.
        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
        return traced_forward, (*args, fwd_seed, fwd_base_offset)
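

# A short usage sketch of the wrapper above (names are illustrative): given a
# joint function and its (primals, tangents) args,
#
#   wrapped, new_args = create_functionalized_rng_ops_wrapper(
#       joint_fn, (primals, tangents), trace_joint=True
#   )
#
# `new_args` is the original args with (fwd_seed, fwd_base_offset, bwd_seed,
# bwd_base_offset) appended, and `wrapped` returns the updated forward and
# backward rng offsets as extra outputs, so the compiled graph can advance the
# Philox state explicitly instead of mutating global RNG state.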


@contextmanager
def set_partitioner_tag(tag: str):
    meta_key = "partitioner_tag"
    assert fx_traceback.has_preserved_node_meta()

    original_val = fx_traceback.current_meta.get(meta_key, None)
    fx_traceback.current_meta[meta_key] = tag
    try:
        yield
    finally:
        fx_traceback.current_meta[meta_key] = original_val


def set_partitioner_tag_is_backward():
    return set_partitioner_tag("is_backward")


def set_partitioner_tag_must_be_in_backward():
    return set_partitioner_tag("must_be_in_backward")


# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
# - fn corresponds to the user's fw function
# - fn arguments have been flattened, duplicate arguments have been handled
# - In the returned function, the "primals" arguments *include* synthetic bases.
# This function does the work of functionalizing the input function,
# and performing copy_() calls at the end of the function if `keep_input_mutations` is set.
# The function returned has a signature that is either:
# (1) "traced_fn(primals: List[Any])" if trace_joint is False
# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True
# Returns a new (functionalized) function, and updated arguments to call it with.
def create_functionalized_fn(
    fn,
    args,
    *,
    meta: ViewAndMutationMeta,
    aot_config: AOTConfig,
    trace_joint: bool,
) -> Any:
    @wraps(fn)
    def _functionalized_f_helper(*args):
        with maybe_enable_thunkify():
            # See Note [Disabling Functionalize TLS Above Python Functionalization]
            disable_above = torch._C._ExcludeDispatchKeyGuard(
                torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
            )

            with disable_above:
                # The functionalization code here can potentially trigger traces
                # into the graph, but we'd prefer to NOT do this, because if we
                # trace them now, we will end up with FX nodes that don't have
                # module stack annotations, which makes the unflattener unhappy.
                # Wrap inputs into functional wrappers
                f_args = pytree.tree_map(to_fun, args)

                # Run the joint
                f_outs = fn(*f_args)

            if trace_joint:
                # We support a limited amount of mutation of graph inputs during the backward pass.
                # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
                # Here, we perform extra checks for primals that were mutated in the **backward**.
                # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
                # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
                #   during the forward, because the handling is different: some input mutations from the forward
                #   can only be handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
                #   types of mutations in the backward we would need a bw-only runtime epilogue.
                # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
                #   the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
This would 426 # require an extra round of tracing though, so it's more efficient to do in-line here. 427 assert ( 428 isinstance(args, tuple) 429 and len(args) == 2 430 and isinstance(args[0], (list, tuple)) 431 ) 432 # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw) 433 primals_before = args[0] 434 primals_after = pytree.tree_map(from_fun, f_args[0]) 435 for idx, (f_inpt, before, after, inpt_info) in enumerate( 436 zip(f_args[0], primals_before, primals_after, meta.input_info) 437 ): 438 # Store information about mutations in joint(for backward analysis) 439 joint_mutates_data = has_data_mutation(f_inpt) 440 441 joint_mutates_metadata = has_metadata_mutation( 442 f_inpt, before, check_only_storage_mutation=False 443 ) 444 445 # Ban metadata mutations on fw inputs during the bw 446 if not inpt_info.mutates_metadata: 447 assert ( 448 not joint_mutates_metadata 449 ), "Found a graph input that had its metadata mutated in the backward. This is not supported" 450 451 # Ban storage resizing on fw inputs during the bw 452 if not inpt_info.mutation_inductor_storage_resize: 453 assert not was_inductor_storage_resized( 454 f_inpt 455 ), "Found a graph input that had storage resizing in the backward. This is not supported" 456 457 # Allow data mutations on fw inputs during the bw, but only if they do not require grad 458 # So we can guarantee that we can keep the mutations in the graph 459 if ( 460 joint_mutates_data 461 and not inpt_info.mutates_data 462 and not inpt_info.mutates_storage_metadata 463 ): 464 # Not banning here mutations on inpt_info.requires_grad - 465 # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) 466 # Add node meta for copy_ for partitioner that this node should be in backward graph. 467 with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward(): 468 before.copy_(after) 469 meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( 470 idx 471 ) 472 # Now that we covered mutations to *forward* inputs during the backward, 473 # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). 474 # Today, we will just error in all cases of this happening unless someone needs us to support it. 475 tangents_before = args[1] 476 tangents_after = pytree.tree_map(from_fun, f_args[1]) 477 for f_inpt, before, after in zip( 478 f_args[1], tangents_before, tangents_after 479 ): 480 assert not has_metadata_mutation( 481 f_inpt, before, check_only_storage_mutation=False 482 ) and not has_data_mutation( 483 f_inpt 484 ), "Found an input to the backward that was mutated during the backward pass. This is not supported" 485 486 if aot_config.keep_inference_input_mutations: 487 # Note: This is a bit annoying. There's a layering issue here, where: 488 # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. 489 # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. 490 # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, 491 # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). 492 # This makes it pretty difficult for this logic to operate on synthetic bases. 
                # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
                #     (unpacked) input aliases, instead of the synthetic base.
                # Example case where (3) could be important:
                #
                #     def f(x, y):
                #         x.mul_(2)
                #         y.mul_(3)
                #         return x, y
                #     a = torch.ones(1_000_000)
                #     x, y = f(a[0:9], a[1:10])
                #
                # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
                # a giant "updated synthetic base" and copying into a's entire storage.
                #
                # For now, we are pessimistically not performing the optimization from (3);
                # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
                # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
                # about synthetic bases.
                for i, (inpt_old, inpt_f) in enumerate(
                    zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
                ):
                    if not isinstance(inpt_f, torch.Tensor):
                        continue
                    assert is_fun(inpt_f)
                    inpt_new = from_fun(inpt_f)
                    if (
                        meta.input_info[i].mutation_type
                        == MutationType.MUTATED_IN_GRAPH
                    ):
                        # See Note [set_() Input Mutations in AOTAutograd]
                        # All mutations on the input must be under no_grad, so it is safe to put in the graph.
                        # Here, we're saying that if an input experienced a set call, inp.set_(other),
                        # then we can effectively not have to worry about whether its data was mutated.
                        # There are 3 cases:
                        # (1) We mutate inp *after* the set_() call. other is a graph intermediate.
                        #     In this case, we're not really mutating the input storage of "inp";
                        #     we're mutating the storage of an intermediate value (other),
                        #     and slamming that storage into the input tensor. So no data mutation is necessary.
                        # (2) We mutate inp *after* the set_() call. other is a graph *input*.
                        #     In this case, the data mutation will be properly handled in the runtime
                        #     epilogue during the processing of "other"
                        # (3) We mutate inp *before* the set_() call.
                        #     This case is *not* currently handled.
                        if meta.input_info[i].mutates_storage_metadata:
                            with torch.no_grad():
                                inpt_old.set_(inpt_new)

                        # Note [Ordering of resize_() and set_()]
                        # Importantly: the common usage in FSDP is that we have a dummy parameter
                        # that sees a set_() and **then** a resize_().
                        # We must put those mutations into the graph in the same order,
                        # since running them in the opposite order will have different behavior.
                        # We fully ban resize_() followed by set_() for now, although in principle
                        # we could support this.
                        if meta.input_info[i].mutation_inductor_storage_resize:
                            # resizing is not supported on subclasses (we error earlier if this happens)
                            from torch._subclasses.functional_tensor import (
                                FunctionalTensor,
                            )

                            assert isinstance(inpt_f, FunctionalTensor)
                            old_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
                                inpt_f.elem, before=True
                            )
                            new_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
                                inpt_f.elem, before=False
                            )
                            if old_storage_size != new_storage_size:
                                assert (
                                    old_storage_size == 0 or new_storage_size == 0
                                ), f"""\
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
                                torch.ops.inductor.resize_storage_bytes_(
                                    inpt_old, new_storage_size
                                )
                            if new_storage_size == 0:
                                # Even if we marked the input as having a data mutation (thus needing a copy_()),
                                # we should **ignore** it if our input has no storage
                                # (this can happen if, e.g. we temporarily resize our input, copy data into it,
                                # and resize it back down to zero)
                                continue
                        # Optimization: if the copy_() is a no-op then don't include it in the graph.
                        # In theory inductor could optimize this away, however in fsdp, we end up with
                        # param.copy_(param), where param is a zero-storage-size tensor,
                        # and running this op in eager mode (using the aot_eager backend) will result in a segfault.
                        # So we may as well optimize it away here.
                        if inpt_old is inpt_new:
                            # (This check needs to be done after putting resize_() in the graph,
                            # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
                            continue
                        # We found an input that had a (data-only) mutation.
                        # Since keep_input_mutations is set, we need to faithfully apply a copy_()
                        # so the compiler will see the input mutation in the graph.
                        if (
                            meta.input_info[i].mutates_data
                            and meta.input_info[i].mutations_hidden_from_autograd
                        ):
                            # Hidden from autograd = run under no_grad, **and** don't bump the VC
                            # (although if the tensor was created in inference mode, it has no VC)
                            if inpt_old.is_inference():
                                maybe_preserve_vc = nullcontext()
                            else:
                                maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
                                    inpt_old  # type: ignore[assignment]
                                )
                            with torch.no_grad(), maybe_preserve_vc:
                                inpt_old.copy_(inpt_new)
                        elif (
                            meta.input_info[i].mutates_data
                            and meta.input_info[
                                i
                            ].mutations_under_no_grad_or_inference_mode
                        ):
                            # Under no_grad = run under no_grad (we still bump the VC though)
                            # (inference_mode will also bump the VC, as long as the tensor in question
                            # was created outside of inference_mode)
                            with torch.no_grad():
                                inpt_old.copy_(inpt_new)
                        elif meta.input_info[i].mutates_data:
                            inpt_old.copy_(inpt_new)

                # When an output tensor is a functionalized mutated input, and we
                # were able to move the mutation into the graph, then we can return
                # the mutated input directly. This prevents duplicating the
                # tensor's contents.
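                # For example (a rough sketch, assuming keep_input_mutations and a
                # hypothetical user function):
                #
                #     def f(x):
                #         x.mul_(2)
                #         return x
                #
                # The mutation stays in the graph as a copy_() into `x`, so the
                # graph output for `x` can simply be the (already mutated) input
                # `x` itself rather than a freshly materialized tensor with the
                # same contents.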
                flat_outs, outs_spec = pytree.tree_flatten(f_outs)
                flat_outs = [from_fun(o) for o in flat_outs]
                num_outs = len(meta.output_info)

                for i, outp in enumerate(flat_outs[:num_outs]):
                    info = meta.output_info[i]
                    if info.output_type != OutputType.is_input:
                        continue

                    assert info.base_idx is not None
                    if (
                        meta.input_info[info.base_idx].mutation_type
                        == MutationType.MUTATED_IN_GRAPH
                    ):
                        fw_args = args[0] if trace_joint else args
                        flat_outs[i] = fw_args[info.base_idx]
                return pytree.tree_unflatten(flat_outs, outs_spec)

            return pytree.tree_map(from_fun, f_outs)

    # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
    # and "tangents" as its input names (which are special-cased by the partitioner)
    # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export
    def joint_helper(primals, tangents):
        return _functionalized_f_helper(primals, tangents)

    helper = joint_helper if trace_joint else _functionalized_f_helper
    if config.functionalize_rng_ops:
        # Set up the wrapper for functionalization of rng ops
        helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)

    return helper, args
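

# A short usage sketch of create_functionalized_fn (variable names are
# illustrative): for the joint case,
#
#   fn_to_trace, updated_args = create_functionalized_fn(
#       joint_fn,
#       (primals, tangents),
#       meta=fw_metadata,
#       aot_config=aot_config,
#       trace_joint=True,
#   )
#
# fn_to_trace has the signature traced_fn(primals, tangents): input mutations
# have been converted into either copy_() calls kept in the graph (when
# keep_input_mutations applies) or extra outputs handled by a runtime epilogue,
# and if config.functionalize_rng_ops is set, rng seed/offset arguments are
# appended to updated_args.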


def handle_effect_tokens_fn(
    fn,
    args,
    *,
    meta: ViewAndMutationMeta,
    trace_joint: bool,
) -> Any:
    num_tokens = len(meta.tokens)

    @wraps(fn)
    def inner_fn(*args):
        # See Note [Disabling Functionalize TLS Above Python Functionalization]
        disable_above = torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )

        with disable_above:
            # See Note [Side-Effectful Tokens in AOTAutograd]
            if trace_joint:
                assert isinstance(args, tuple) and isinstance(args[0], (list, tuple))
                tokens = args[0][:num_tokens]
                assert all(token.numel() == 0 for token in tokens)
                args = (args[0][num_tokens:], *args[1:])
            else:
                tokens = args[:num_tokens]
                assert all(token.numel() == 0 for token in tokens)
                args = args[num_tokens:]

            # Populate the current FunctionalTensorMode with the tokens per
            # operator. See Note [FunctionalTensorMode is Stateful]
            functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )
            assert functional_tensor_mode is not None
            f_tokens = pytree.tree_map(to_fun, tokens)
            for i, k in enumerate(meta.tokens.keys()):
                functional_tensor_mode._tokens[k] = f_tokens[i]

            # Run the joint
            outs = fn(*args)

        # Return both the tokens and the outputs
        # See Note [Side-Effectful Tokens in AOTAutograd]
        if trace_joint:
            assert len(outs) == 2
            assert len(functional_tensor_mode._tokens_forward_output) == num_tokens
            fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values()

            bwd_out_tokens = functional_tensor_mode._tokens.values()

            f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens]
            f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens]

            meta.num_backward_tokens = len(bwd_out_tokens)
            return ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens))

        out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()]
        return (*out_tokens, *outs)

    # Additionally pass in tokens as inputs
    # See Note [Side-Effectful Tokens in AOTAutograd]
    additional_fwd_token_inputs = [torch.tensor([])] * num_tokens

    if trace_joint:
        args = ([*additional_fwd_token_inputs, *args[0]], *args[1:])
    else:
        args = [*additional_fwd_token_inputs, *args]
    return inner_fn, args
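

# A rough sketch of the token calling convention above (shapes and counts are
# illustrative): with one side-effectful op recorded in meta.tokens, the
# forward-only wrapper is called as
#
#   inner_fn(token0, *primals)   # token0 is an empty tensor
#
# and returns (updated_token0, *user_outs). In the joint case, forward tokens
# are prepended to the forward outputs and any backward tokens are appended to
# the gradient outputs. The tokens carry no data; they only thread dependencies
# through the graph so side-effectful ops stay ordered.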


# Given a function operating on Subclass -> Subclass, returns a function that operates on Tensor -> Tensor
# Also returns:
# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated)
# - the updated ViewAndMutationMeta for this dense -> dense function.
# The other important arguments are:
# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function.
#                        when is_joint_structure=False, this is just the forward function.
# - fw_only: this is *always* the forward-only function.
#   Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions.
#   In particular, we need this to tell the partitioner how many dense forward outputs there are.
def aot_dispatch_subclass(
    flat_fn_maybe_joint,
    args: List[Any],
    *,
    is_joint_structure: bool,
    meta: ViewAndMutationMeta,
    fw_only: Callable,
) -> SubclassTracingInfo:
    # Skip logic if we don't need to trace through any subclasses
    req_subclass_dispatch = requires_subclass_dispatch(args, meta)
    if not req_subclass_dispatch:
        return SubclassTracingInfo(
            plain_tensor_trace_fn=flat_fn_maybe_joint,
            plain_tensor_args=args,
            maybe_subclass_meta=None,
        )

    # TODO: add subclass guards (later PR).

    # What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs).
    # Annoying: we don't know the grad input metas until we're in the middle of tracing the joint,
    # so we set it later, while we're tracing the joint (see inner_fn() below).
    # Another option would be to run our run_functionalized_fw_and_collect_metadata() function
    # directly on the joint, but this would hurt compile time (adding yet another pass through the joint).
    subclass_meta = SubclassMeta()

    def inner_fn(fn, args, *, use_trace_joint: bool):
        # Step 1: wrap tensor inputs into subclasses if necessary
        all_args = wrap_tensor_subclasses_maybe_joint(
            args, is_joint_structure=use_trace_joint, meta=meta
        )

        # Step 2: call the inner function, with our (maybe subclass) inputs
        wrapped_outs = fn(*all_args)

        if use_trace_joint:
            # See Note: [Computing Subclass Metadata about grad_inputs]
            # We also stash subclass info on our grad_inputs, if we're tracing the joint.
            nonlocal subclass_meta
            assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2
            # Don't need fw outs since we already have subclass metadata on them
            grad_inputs = wrapped_outs[1]
            subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs)

        # Step 3: Unwrap any subclass outputs back into dense tensors
        unwrapped_outs = unwrap_tensor_subclasses(
            wrapped_outs, is_joint_structure=use_trace_joint
        )
        return unwrapped_outs

    def joint_fn(primals, tangents):
        with maybe_enable_thunkify():
            return inner_fn(
                flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
            )

    def fw_fn(*primals):
        with maybe_enable_thunkify():
            return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)

    def metadata_fn(*primals):
        return inner_fn(fw_only, primals, use_trace_joint=False)

    args_unwrapped = unwrap_tensor_subclasses(
        args, is_joint_structure=is_joint_structure
    )
    remapped_static_indices = remap_unwrapped_subclass_arg_indices(
        args, meta.static_input_indices
    )

    if is_joint_structure:
        primals_unwrapped = args_unwrapped[0]
        fn_to_trace = joint_fn
    else:
        primals_unwrapped = args_unwrapped
        fn_to_trace = fw_fn

    # Note: [Partitioner handling for Subclasses, Part 1]
    # The way the partitioner works is that:
    # (1) we pass in a single graph containing the joint fw/bw,
    #     where the # of graph outputs corresponds to # fw_outputs + # grad_inputs
    # (2) The partitioner accepts an argument, num_fwd_outputs,
    #     and assumes that the first "num_fwd_outputs" graph outputs correspond
    #     to outputs of the forward graph.
    # How do tensor subclasses enter the picture?
    # The num_fwd_outputs in the final graph is actually non-trivial to compute,
    # because it can be influenced by input mutations and intermediate bases.
    # So we compute it by inspecting the current ViewAndMutationMeta object.
    # However, the original ViewAndMutationMeta that we computed was created
    # on the subclass -> subclass graph,
    # which can have a different number of outputs than the dense -> dense graph.
    # That's why we create a fresh metadata object on the dense -> dense function here,
    # and plumb it back up to the partitioner.
    # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
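    # For intuition (hypothetical example, using the TwoTensor test subclass):
    # if the user forward returns a single TwoTensor, the subclass-level graph
    # has one forward output while the dense graph traced here has two (one per
    # inner tensor), so num_fwd_outputs must come from the metadata computed on
    # metadata_fn below rather than from the original subclass-level metadata.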
    meta_updated = run_functionalized_fw_and_collect_metadata(
        metadata_fn,
        static_input_indices=remapped_static_indices,
        keep_input_mutations=meta.keep_input_mutations,
        is_train=meta.is_train,
    )(*primals_unwrapped)

    subclass_meta.fw_metadata = meta_updated

    return SubclassTracingInfo(
        plain_tensor_trace_fn=fn_to_trace,
        plain_tensor_args=args_unwrapped,
        maybe_subclass_meta=subclass_meta,
    )


def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
    # Redundant with dynamo, but worth having in case this gets invoked elsewhere.
    # https://github.com/pytorch/pytorch/issues/103569

    def functional_call(*args, **kwargs):
        with stateless._reparametrize_module(
            mod, pytree.tree_unflatten(args[:params_len], params_spec)
        ), maybe_disable_thunkify():
            if isinstance(mod, torch.fx.GraphModule):
                with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore", "Anomaly Detection has been enabled."
                    )
                    with torch.autograd.detect_anomaly(check_nan=False):
                        detect_fake_mode().epoch += 1
                        out = PropagateUnbackedSymInts(mod).run(
                            *args[params_len:], **kwargs
                        )
            else:
                out = mod(*args[params_len:], **kwargs)

        if not isinstance(out, (tuple, list)):
            raise RuntimeError(
                "Graph output must be a tuple(). This is so that we can avoid "
                "pytree processing of the outputs. Please change the module to "
                "have tuple outputs or use aot_module instead."
            )
        return out

    # Note [Preserving the nn module stack metadata during export non-strict mode]
    # This path is currently only used by the non-strict export flow,
    # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph.
    # Instead, we stash the original user nn module here, and rely on `make_fx` to grab
    # this stashed module and use it to track nn module stack metadata
    if store_orig_mod and not hasattr(functional_call, "_orig_mod"):
        functional_call._orig_mod = mod  # type: ignore[attr-defined]

    return functional_call
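

# A minimal usage sketch of create_functional_call (for illustration only; the
# real callers also flatten buffers alongside parameters):
#
#   params = dict(mod.named_parameters(remove_duplicate=False))
#   params_flat, params_spec = pytree.tree_flatten(params)
#   flat_fn = create_functional_call(mod, params_spec, len(params_flat))
#   out = flat_fn(*params_flat, *user_inputs)
#
# The flattened parameters are passed positionally ahead of the user inputs,
# and the module is temporarily reparametrized with them for the duration of
# the call.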