1# mypy: allow-untyped-defs 2""" 3This module is one of the analysis modules - it takes as input a function or graph 4and some preexisting properties, and returns some data that is useful for deciding 5how to further proceed with compilation or construct runtime wrappers. 6 7In particular, the analysis here constructs view and mutation metadata from running 8a functionalized version of the graph under compilation. 9""" 10 11import collections 12import contextlib 13import logging 14from functools import wraps 15from typing import Callable, DefaultDict, Dict, List, Optional 16 17import torch 18import torch.utils._pytree as pytree 19from torch import Tensor 20from torch._guards import detect_fake_mode 21from torch._logging import getArtifactLogger 22from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode 23from torch._subclasses.meta_utils import safe_is_leaf 24from torch.fx.experimental.symbolic_shapes import is_concrete_int 25from torch.multiprocessing.reductions import StorageWeakRef 26from torch.utils._python_dispatch import ( 27 is_traceable_wrapper_subclass, 28 transform_subclass, 29) 30 31from .functional_utils import ( 32 are_all_mutations_hidden_from_autograd, 33 are_all_mutations_under_no_grad_or_inference_mode, 34 from_fun, 35 has_data_mutation, 36 has_metadata_mutation, 37 has_same_metadata, 38 to_fun, 39 was_inductor_storage_resized, 40) 41from .schemas import ( 42 FunctionalTensorMetadataEq, 43 InputAliasInfo, 44 MutationType, 45 OutputAliasInfo, 46 OutputType, 47 ViewAndMutationMeta, 48) 49from .subclass_utils import create_subclass_meta 50from .utils import _get_autocast_states, KNOWN_TYPES, strict_zip 51 52 53zip = strict_zip 54 55log = logging.getLogger(__name__) 56static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs") 57 58 59# Note [Tangents must be contiguous] 60# We force tangents to be contiguous today. 61# The idea is that we are technically making a guess about the strides of our tangents, 62# while we trace out the joint. 63# Today, we force this guess to be correct by additioanlly calling contiguous() 64# on all tangents at runtime. 65# In the future, you could imagine lifting this restriction, since these contiguous() 66# calls can have noticeable perf overhead depending on the model. 67def coerce_tangent(x): 68 if not isinstance(x, Tensor): 69 return x 70 out = x.detach().contiguous() 71 # Note [Tangents must be contiguous, Part 2] 72 # In the same way that "what strides do we assigns to our tangents" is a question 73 # that we can not answer (and therefore have to guess) as we trace the backward ahead-of-time, 74 # The same applies to any tensor subclass metadata, when we have tangents that are subclasses. 75 # To handle this situation, we have two new methods that a tensor subclass can implement: 76 # (1) __coerce_tangent_metadata__(self) 77 # Given a subclass with "non-standard" metadata, turn it into a new subclass with "normal" metadata. 78 # The main example here is a DTensor with the "_Partial" placement. 79 # If we have a forward output with a _Partial placement, and corresponding tangent 80 # with a Replicate/Shard placement, we have no way to convert the tangent "back" to a _Partial placement. 81 # This method lets us avoid the problem entirely by allowing subclasses to ensure that we can never 82 # have a tangent with "problematic" metadata, that we cannot convert to. 83 # (1) __coerce_same_metadata_as_tangent__(self, metadata) 84 # Given a subclass, and a target differing metadata, 85 # convert self to have the same metadata as the target. 86 # With DTensor being the main example, we can use this to convert a DTensor with a Replicate() 87 # placement into one with a Shard() placement, in the case that we "guessed wrong", 88 # and traced tangents with a Shard() placement at compile time. 89 # 90 if is_traceable_wrapper_subclass(out) and hasattr( 91 out, "__coerce_tangent_metadata__" 92 ): 93 out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined] 94 # It's possible to have a subclass that advertises as contiguous, 95 # but has noncontiguous inner tensors. 96 # Force these to be conntiguous too 97 if is_traceable_wrapper_subclass(out): 98 for attr in out.__tensor_flatten__()[0]: # type: ignore[attr-defined] 99 elem = getattr(out, attr) 100 if not elem.is_contiguous(): 101 elem_contig = elem.contiguous() 102 setattr(out, attr, elem_contig) 103 return out 104 105 106# This is a version of functionalization that is specifically designed 107# for the AOTAutograd use case. 108# 109# Unlike functorch's variant, this doesn't use the functorch level system, 110# instead it directly uses PyTorch's conventional dispatcher to hit the 111# functionalization key. In particular, this means that FunctionalTensorWrapper 112# can have autograd data stored directly on it. 113# 114# In typical AOTAutograd usage, the dispatch key order will look like: 115# 116# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor 117# outer tensor inner tensor 118# 119# Returns: 120# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and 121# The list of outputs from the forward, but **only** the outputs that we need 122# to pass in as tangents into the backward. 123# Specifically, aliased outputs from the forward get regenerated, and don't participate 124# in the compiled backward function. 125def run_functionalized_fw_and_collect_metadata( 126 f, 127 *, 128 keep_input_mutations: bool, 129 # TODO: refactor to kill this flag 130 is_train: bool = False, 131 # Note: this is guaranteed to be set when running under dynamo 132 static_input_indices: Optional[List[int]] = None, 133 pre_dispatch: bool = False, 134) -> Callable[..., ViewAndMutationMeta]: 135 memo: Dict[Tensor, Tensor] = {} 136 137 def _to_fun(t): 138 if isinstance(t, Tensor): 139 if t in memo: 140 return memo[t] 141 r = to_fun(t) 142 memo[t] = r 143 return r 144 else: 145 return t 146 147 @wraps(f) 148 def inner(*flat_args): 149 # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. 150 assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args) 151 152 input_info: List[InputAliasInfo] = [] 153 output_info: List[OutputAliasInfo] = [] 154 155 prior_grad_enabled = torch.is_grad_enabled() 156 prior_autocast_states = _get_autocast_states() 157 158 # See Note [Disabling Functionalize TLS Above Python Functionalization] 159 disable_above = torch._C._ExcludeDispatchKeyGuard( 160 torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) 161 ) 162 163 # It doesn't matter if we run this under predispatch or not because it is 164 # only for figuring out metadata 165 mode = FunctionalTensorMode(_allow_token_discovery=True) 166 suppress_pending = contextlib.nullcontext() 167 fake_mode = detect_fake_mode() 168 if fake_mode and (shape_env := fake_mode.shape_env): 169 suppress_pending = shape_env.ignore_fresh_unbacked_symbols() 170 with disable_above, mode, suppress_pending: 171 # precondition: The passed in function already handles unflattening inputs + flattening outputs 172 flat_f_args = pytree.tree_map(_to_fun, flat_args) 173 flat_f_outs = f(*flat_f_args) 174 # We didn't do any tracing, so we don't need to process the 175 # unbacked symbols, they will just disappear into the ether. 176 # Also, prevent memoization from applying. 177 if fake_mode: 178 fake_mode.epoch += 1 179 fake_mode.reset_nt_tensor_id_counter() 180 181 if prior_autocast_states != _get_autocast_states(): 182 raise RuntimeError( 183 "AOTAutograd does not support tracing graphs that mutate the autocast state. " 184 "Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, " 185 "which will unwind all of their mutations to autocast state before the graph exits. " 186 "If you encounter this error while using torch.compile, please file a bug." 187 ) 188 189 # Inspect the state of the input tensor functional wrapper to detect input mutation info 190 # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version 191 for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)): 192 # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in 193 # strides between the functionalized arg inner tensors and non-functionalized arg inner 194 # tensors. This is a problem as the inner tensor stride change may not be reflected 195 # correctly in the outer tensor, so disallow this for now. 196 mutates_data = has_data_mutation(f_arg) 197 if ( 198 mutates_data 199 and not arg.is_contiguous() 200 and is_traceable_wrapper_subclass(arg) 201 ): 202 raise RuntimeError( 203 "Mutations on non-contiguous inputs are currently not allowed on " 204 "tensor subclasses" 205 ) 206 207 if not isinstance(arg, Tensor): 208 new_arg = arg 209 else: 210 new_arg = from_fun(f_arg) 211 mutates_metadata = has_metadata_mutation( 212 f_arg, arg, check_only_storage_mutation=False 213 ) 214 if mutates_metadata and is_traceable_wrapper_subclass(arg): 215 raise RuntimeError( 216 "Metadata mutations are currently not allowed on tensor subclasses" 217 ) 218 mutates_storage_metadata = has_metadata_mutation( 219 f_arg, arg, check_only_storage_mutation=True 220 ) 221 mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd( 222 f_arg 223 ) 224 mutations_under_no_grad_or_inference_mode = ( 225 mutates_data 226 and are_all_mutations_under_no_grad_or_inference_mode(f_arg) 227 ) 228 mutation_inductor_storage_resize = was_inductor_storage_resized(f_arg) 229 230 if mutates_storage_metadata: 231 mutates_data = False 232 233 requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad 234 235 input_info.append( 236 InputAliasInfo( 237 is_leaf=isinstance(arg, Tensor) and safe_is_leaf(arg), 238 mutates_data=mutates_data, 239 mutates_metadata=mutates_metadata, 240 mutations_hidden_from_autograd=mutations_hidden_from_autograd, 241 mutates_storage_metadata=mutates_storage_metadata, 242 mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, 243 mutation_inductor_storage_resize=mutation_inductor_storage_resize, 244 requires_grad=requires_grad, 245 keep_input_mutations=keep_input_mutations, 246 ) 247 ) 248 249 # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate, 250 # We need to make sure our graph returns the _base as a graph output, and we manually recreate the view 251 # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad 252 # on the base tensor, but we are obligated to properly set requires-gradness on the real output. 253 254 inp_storage_refs = { 255 StorageWeakRef(inpt.untyped_storage()): idx 256 for idx, inpt in enumerate(flat_f_args) 257 if isinstance(inpt, Tensor) 258 } 259 260 # We need inp tensor id's to be able to tell if an outputs **are** inputs. 261 inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, Tensor)} 262 # We need output tensor id's to tell if any output._base` attributes **are** other outputs. 263 # (This is also a dict because we need to know that output's index, so we can regenerate 264 # the alias from it). 265 out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)} 266 267 # Keep track of which outputs alias other outputs 268 out_tensor_alias_counts: DefaultDict = collections.defaultdict(int) 269 # This tells us, for a given group of outputs that alias each other, 270 # whether they e.g. all came from an unbind call 271 num_aliased_tensors_that_are_multi_output_views: DefaultDict = ( 272 collections.defaultdict(int) 273 ) 274 out_storage_to_tensors: DefaultDict = collections.defaultdict(set) 275 curr_storage = None 276 for o in flat_f_outs: 277 if isinstance(o, torch.Tensor): 278 curr_storage = StorageWeakRef(o.untyped_storage()) 279 out_tensor_alias_counts[curr_storage] += 1 280 # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] 281 # This is an optimization on top of the "alias of intermediates" logic, 282 # which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!] 283 # 284 # Before describing the optimization: this is important for AOTAutograd to have good 285 # perf around, multi-output views. HOWEVER: 286 # - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case, 287 # around using pre-dispatch tracing to partition out a graph so we can faithfully replay all 288 # views without having to regenerate them at runtime. 289 # - It's loosely described in this doc (more details will be added soon): 290 # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit 291 # - Once that change lands, we should just rip out this "optimization", since: 292 # (1) It will be fully unnecessary 293 # (2) Although it is only a few lines of code, it is a bit difficult to reason about 294 # its correctness with the autograd engine in all cases. 295 # 296 # 297 # What is this optimization? Consider the below case: 298 # def f(x): 299 # intermediate = x.mul(2) 300 # # x and intermediate here require grad 301 # o1, o2, ... o10 = intermediate.unbind(-1) 302 # return intermediate, o1, o2, ... o10 303 # Now, the "intermediate base" handling in AOTAutograd implies that we must do the following: 304 # (1) return "intermediate as an extra output of the compiled graph 305 # (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function. 306 # The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know 307 # that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function, 308 # this information will be hidden. 309 # In particular, mutating one alias might require autograd to update autograd metadata on the other aliases 310 # (like their grad_fn, for example, when the autograd engine needs to do view-replay). 311 # 312 # However, intermediate_base logic can be bad for backward performance (we sometimes generate 313 # as_strided calls during the intermediate base logic, which can have a slow backward formula). 314 # Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd? 315 # 316 # For a set of outputs of the graph that alias each other, o_1...o_k, consider: 317 # (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0) 318 # (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate), 319 # **at most** 1 can escape from the graph (e.g. there is not some other graph input/output 320 # o_other, that aliases these outputs) 321 # (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad. 322 # This condition is important because it's what causes slowness in the intermediate_base 323 # codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and 324 # aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn. 325 # "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward. 326 # In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta 327 # of the other aliases? 328 # 329 # Claim: No! Consider a few example (which I'm pretty sure cover all cases of mutation w.r.t. autograd): 330 # (a) What happens if we mutate any of o_1 through o_k directly? 331 # Autograd raises an error: 332 # "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is 333 # the output of a function that returns multiple views. Such functions do not allow the output 334 # views to be modified inplace. You should replace the inplace operation by an out-of-place one." 335 # (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)? 336 # Autograd raises the same error- the "multi-output-view"ness of an alias propagates to future views. 337 # (c) What if we mutate o_k under no_grad? 338 # Autograd raises the same error 339 # (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)? 340 # Autograd allows this, *but* autograd updates all alias's grad_fn's to be error functions when accessed. 341 # Autograd raises the same error 342 # (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view? 343 # We promised that there is at most **one** such alias, e.g. intermediate in the example above. 344 # You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k 345 # to be error fn's. 346 # Since intermediate was the *only* non-multi-output-alias, there are no other aliases 347 # of `intermediate` around that were produced by the compiled fn and have a valid grad_fn. 348 # 349 # Coming back to this optimization: 350 # Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias 351 # without causing an error in eager mode, we will simple hide the aliasing from autograd during torch.compile 352 # if all of the above conditions are met. 353 # This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on 354 # in eager but fail to during torch.compile, but it has the benefit that this code has much better performance. 355 # NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here: 356 # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit, 357 # then this optimization will probably matter less and might be ok to remove. 358 is_cur_tensor_multi_out_view = isinstance( 359 o, FunctionalTensor 360 ) and torch._functionalize_is_multi_output_view( # type: ignore[attr-defined] 361 o.elem 362 ) 363 if is_cur_tensor_multi_out_view: 364 num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1 365 out_storage_to_tensors[curr_storage].add(o) 366 367 # maps the id of an intermediate base to its index in the output of the compiled forward 368 intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {} 369 intermediate_bases: List[torch.Tensor] = [] 370 # Why Do We Care If Storage Changed? 371 # It's important to understand the implications of storage changes in complex scenarios. Take this example: 372 # 373 # def f(x): 374 # x_storage = x.untyped_storage() 375 # non_leaf_tensor = torch.ones(4, requires_grad=True).clone() 376 # 377 # # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation 378 # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): 379 # x.set_(non_leaf_tensor.untyped_storage()) 380 # 381 # out = x.view(-1) 382 # 383 # # Restoring x to its original storage, again simulating .data = operation 384 # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): 385 # x.set_(x_storage) 386 # 387 # return out 388 # 389 # In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, aka no aliasing. 390 # However, due to how set_() and more specificlaly, set is functionalized, is defined to preserve eager semantics, 391 # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'. 392 # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated, 393 # which could lead to issues later in the code. 394 for o in flat_f_outs: 395 functional_tensor_storage_changed = isinstance( 396 o, FunctionalTensor 397 ) and torch._functionalize_was_storage_changed( # type: ignore[attr-defined] 398 o.elem 399 ) 400 curr_storage = ( 401 None 402 if not isinstance(o, torch.Tensor) 403 else StorageWeakRef(o.untyped_storage()) 404 ) 405 outs_with_identical_metadata_that_require_grad = ( 406 [] 407 if not isinstance(o, Tensor) 408 else [ 409 curr 410 for curr in out_storage_to_tensors[curr_storage] 411 if has_same_metadata(o, curr) 412 and curr.requires_grad 413 and o is not curr 414 ] 415 ) 416 417 # See Note [Accessing .grad_fn on FunctionalTensor] 418 # In-place operations on views will trigger a lazy rebase of the autograd graph; 419 # this runs during access to the .grad_fn. The rebase logic will invoke view ops 420 # on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure 421 # these op calls succeed. 422 grad_fn = None 423 if isinstance(o, Tensor): 424 with FunctionalTensorMode(): 425 grad_fn = o.grad_fn 426 427 is_result_of_custom_autograd_fn = False 428 # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) 429 # autograd fns 430 if type(grad_fn).__name__ == "CppFunction": 431 is_result_of_custom_autograd_fn = True 432 if isinstance(grad_fn, torch.autograd.function.BackwardCFunction): 433 is_result_of_custom_autograd_fn = True 434 435 if not isinstance(o, Tensor): 436 output_type = OutputType.non_alias 437 base_idx = None 438 elif ( 439 curr_storage in inp_storage_refs 440 and grad_fn is not None 441 and is_result_of_custom_autograd_fn 442 ): 443 output_type = OutputType.custom_function_view 444 base_idx = None 445 elif ( 446 curr_storage in inp_storage_refs 447 and not functional_tensor_storage_changed 448 ): 449 base_idx = inp_storage_refs[curr_storage] 450 is_input_tensor = id(o) in inp_tensor_ids 451 num_aliased_outs = out_tensor_alias_counts[curr_storage] 452 num_multi_output_view_outs = ( 453 num_aliased_tensors_that_are_multi_output_views[curr_storage] 454 ) 455 num_aliased_outs_that_are_not_multi_output_views = ( 456 num_aliased_outs - num_multi_output_view_outs 457 ) 458 if ( 459 grad_fn is not None 460 and num_aliased_outs_that_are_not_multi_output_views == 0 461 ): 462 # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] 463 # In particular, given: 464 # def f(x): 465 # return list(x.unbind(0)) 466 # The main reason we ordinarily try to regenerate these output aliases outside of the 467 # compiled autograd.Function is because if any of the outputs are later mutated, 468 # autograd needs to perform view-replay to regenerate them. 469 # However, autograd does not allow users to mutate multi-output views 470 # in any way that can change the autograd metadata of other aliases. 471 # So we hide this aliasing from autograd here. 472 log.debug( 473 "Encountered AOTAutograd case: differentiable outputs that \ 474alias each other from a multi-output view call" 475 ) 476 output_type = OutputType.non_alias 477 elif is_input_tensor: 478 output_type = OutputType.is_input 479 else: 480 output_type = OutputType.alias_of_input 481 elif functional_tensor_storage_changed and id(o) in inp_tensor_ids: 482 # When there is a set_() on an input, we cannot rely on checking storages 483 # to detect if we are returning an input (since the inputs storage is different) 484 assert curr_storage is not None 485 base_idx = inp_storage_refs[curr_storage] 486 output_type = OutputType.is_input 487 488 # We only need to handle the intermediate base case when both 489 # the intermediate base and the output require gradients. 490 # See Note [AOT Autograd: outputs aliasing inputs or intermediates!] 491 elif o._base is not None and o.requires_grad and o._base.requires_grad: 492 num_aliased_outs = out_tensor_alias_counts[curr_storage] 493 num_multi_output_view_outs = ( 494 num_aliased_tensors_that_are_multi_output_views[curr_storage] 495 ) 496 num_aliased_outs_that_are_not_multi_output_views = ( 497 num_aliased_outs - num_multi_output_view_outs 498 ) 499 # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] 500 if ( 501 out_tensor_alias_counts[curr_storage] == 1 502 or num_aliased_outs_that_are_not_multi_output_views <= 1 503 ): 504 # Note [Intermediate Bases Optimization] 505 # Normally if we have an output that aliases an intermediate, 506 # we need to add the extra "intermediate base" logic further down 507 # to prevent autograd from yelling at us if the user later tries to 508 # mutate that output. 509 # However, the common case here is if we have an output that aliases an intermediate, 510 # but doesn't alias any other outputs. 511 # In that case, autograd shouldn't have to worry about the aliasing at all 512 # (if that output is mutated, there are no other live aliases for autograd to worry about). 513 # The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs. 514 # So as an optimization, we won't do intermediate base handling in this case. 515 # Instead, we'll hide the aliasing from autograd using aten._unsafe_view(). 516 if ( 517 out_tensor_alias_counts[curr_storage] != 1 518 and num_aliased_outs_that_are_not_multi_output_views <= 1 519 ): 520 log.debug( 521 "Encountered AOTAutograd case: differentiable outputs that alias each other \ 522from a multi-output view call" 523 ) 524 output_type = OutputType.unsafe_view_alias 525 base_idx = None 526 else: 527 # First, check if o's ._base is an existing output 528 maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None) 529 if maybe_existing_out_idx is not None: 530 # Special case where the output is an alias of a graph intermediate, but that intermediate 531 # is itself also a user output. 532 output_type = ( 533 OutputType.alias_of_intermediate_base_is_user_output 534 ) 535 base_idx = maybe_existing_out_idx 536 else: 537 # Next, check if o's ._base is an intermediate base that we already returned 538 maybe_existing_base_output_idx = ( 539 intermediate_base_tensor_id_to_output_idx.get( 540 id(o._base), None 541 ) 542 ) 543 if maybe_existing_base_output_idx is not None: 544 output_type = OutputType.alias_of_intermediate 545 base_idx = maybe_existing_base_output_idx 546 else: 547 # Otherwise, take o._base and explicitly return it as an output in the compiled graph 548 new_out_idx = len(intermediate_bases) 549 base_idx = new_out_idx 550 # Indicate to the logic later on (when we trace the joint) 551 # that this particular output should get it's ._base appended to the forward graph outputs 552 output_type = ( 553 OutputType.alias_of_intermediate_save_as_output 554 ) 555 intermediate_base_tensor_id_to_output_idx[ 556 id(o._base) 557 ] = new_out_idx 558 intermediate_bases.append(o._base) 559 elif ( 560 # See https://github.com/pytorch/pytorch/issues/100348 for this case. 561 # This protects against the specific case where a user fn returns (output, output.detach()) 562 out_tensor_alias_counts[curr_storage] > 1 563 and len(outs_with_identical_metadata_that_require_grad) > 0 564 and not o.requires_grad 565 ): 566 # In theory we could use any of these tensors to regenerate the aliased outputs from, 567 # since they all alias each other and have identical metatadata 568 out_alias = outs_with_identical_metadata_that_require_grad[0] 569 existing_out_idx = out_tensor_ids[id(out_alias)] 570 output_type = OutputType.alias_of_intermediate_base_is_user_output 571 base_idx = existing_out_idx 572 else: 573 output_type = OutputType.non_alias 574 base_idx = None 575 576 if isinstance(o, torch.Tensor): 577 dynamic_dims = { 578 i for i, s in enumerate(o.shape) if not is_concrete_int(s) 579 } 580 else: 581 dynamic_dims = None 582 583 # Save the current FunctionalTensor output. 584 # 585 # This will be used at runtime for reconstructing output views from 586 # their respective base tensors. 587 # 588 # The FunctionalTensor will be saved if one of the 2 conditions below 589 # is true: 590 functional_tensor = None 591 if ( 592 # 1. If the output_type is either of: 593 # (i) alias_of_intermediate; 594 # (ii) alias_of_intermediate_save_as_output; or 595 # (iii) alias_of_intermediate_base_is_user_output. 596 # 597 # No need to worry about in-place view operations here, since 598 # this functionalization step elimitates mutations. 599 # 600 # i.e. we have access to the actual base tensor, before the 601 # in-place operation was applied. 602 output_type 603 in ( 604 OutputType.alias_of_intermediate, 605 OutputType.alias_of_intermediate_save_as_output, 606 OutputType.alias_of_intermediate_base_is_user_output, 607 ) 608 ) or ( 609 # 2. If the output_type is alias_of_input, and no in-place view 610 # operationthe was run on the input (base tensor). 611 # 612 # In this case, we need to check for metadata mutation because 613 # the runtime explicitly reconstructs the inputs, before actually 614 # reconstructing the outputs. Due to in-place view operations, the 615 # fully reconstructed input may not be this output base tensor 616 # anymore. 617 output_type == OutputType.alias_of_input 618 and base_idx is not None 619 and not input_info[base_idx].mutates_metadata 620 ): 621 if isinstance(o, FunctionalTensor): 622 functional_tensor = FunctionalTensorMetadataEq(o.elem) 623 624 out_info = OutputAliasInfo( 625 output_type=output_type, 626 raw_type=type(o), 627 base_idx=base_idx, 628 dynamic_dims=dynamic_dims, 629 requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, 630 functional_tensor=functional_tensor, 631 ) 632 output_info.append(out_info) 633 634 # See Note [AOT Autograd: Views to avoid tangents aliasing inputs] 635 def view_avoid_dupes_with_primals(t): 636 if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): 637 return transform_subclass( 638 t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t) 639 ) 640 if isinstance(t, Tensor): 641 return t.view(t.shape) 642 return t 643 644 # This analysis function returns *only* the outputs that are meant to be tangents to the backwards. 645 # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates) 646 # are *regenerated* later, and not used directly in the autograd graph 647 f_input_tangents = [ 648 inp 649 for inp, info in zip(flat_f_args, input_info) 650 if info.mutation_type == MutationType.MUTATED_OUT_GRAPH 651 and info.mutates_data 652 and info.requires_grad 653 ] 654 f_output_tangents = [ 655 o 656 for o, info in zip(flat_f_outs, output_info) 657 if info.output_type 658 in [ 659 OutputType.non_alias, 660 OutputType.unsafe_view_alias, 661 OutputType.custom_function_view, 662 ] 663 and issubclass(info.raw_type, torch.Tensor) 664 and info.requires_grad 665 ] 666 # intermediate bases are also included in the backward graph 667 f_tangents = f_input_tangents + f_output_tangents + intermediate_bases 668 traced_tangents = pytree.tree_map(from_fun, f_tangents) 669 traced_tangents = pytree.tree_map( 670 view_avoid_dupes_with_primals, traced_tangents 671 ) 672 # See Note [Tangents must be contiguous] 673 traced_tangents = pytree.tree_map( 674 coerce_tangent, 675 traced_tangents, 676 ) 677 user_outs = pytree.tree_map(from_fun, f_output_tangents) 678 679 nonlocal static_input_indices 680 static_input_indices = static_input_indices or [] 681 if torch._dynamo.compiled_autograd.in_compiled_autograd_region: 682 passed_indices = set(static_input_indices) 683 static_input_indices = [ 684 i 685 for i, arg in enumerate(flat_args) 686 if (isinstance(arg, torch.nn.Parameter) or i in passed_indices) 687 ] 688 689 static_input_logger.debug( 690 "static input indices metadata analysis: %s", static_input_indices 691 ) 692 693 f_mutated_inputs = [ 694 inp 695 for inp, info in zip(flat_f_args, input_info) 696 if info.mutation_type == MutationType.MUTATED_OUT_GRAPH 697 ] 698 f_metadata_mutated_inputs = [ 699 inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata 700 ] 701 # This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be. 702 # When handling subclasses, we need info about **all** outputs of compiled forward graph, 703 # so we know precisely which graph outputs to wrap back into tensor subclasses 704 # Ideally we would refactor this so not have an is_train flag, and have the separate 705 # inference and training paths decide which inputs/output to ask for subclass info on. 706 # However, we currently stash indexing information on each SubclassMeta about its order 707 # in the graph outputs list. 708 f_fw_graph_outs = list(flat_f_outs) 709 if is_train or not keep_input_mutations: 710 f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs 711 else: 712 # even when "keep_input_mutations" is True, 713 # we never keep metadata-only mutations in the fw graph 714 f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs 715 if is_train: 716 f_fw_graph_outs = f_fw_graph_outs + intermediate_bases 717 fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs) 718 719 grad_enabled_mutation = None 720 if torch.is_grad_enabled() != prior_grad_enabled: 721 grad_enabled_mutation = torch.is_grad_enabled() 722 torch.set_grad_enabled( 723 prior_grad_enabled 724 ) # Restore the prior state after tracing it 725 log.debug( 726 ( 727 "grad_mode mutation encountered in graph. " 728 "Will emit mutation epilogue, to set grad_mode=%s" 729 ), 730 grad_enabled_mutation, 731 ) 732 733 metadata = ViewAndMutationMeta( 734 input_info=input_info, 735 output_info=output_info, 736 num_intermediate_bases=len(intermediate_bases), 737 keep_input_mutations=keep_input_mutations, 738 traced_tangents=traced_tangents, 739 subclass_inp_meta=create_subclass_meta(flat_args), 740 subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs), 741 subclass_tangent_meta=create_subclass_meta(traced_tangents), 742 is_train=is_train, 743 grad_enabled_mutation=grad_enabled_mutation, 744 static_input_indices=static_input_indices, 745 tokens=mode._tokens, 746 ) 747 return metadata 748 749 return inner 750