# mypy: allow-untyped-defs
"""
This file contains utilities related to functionalization in AOTAutograd:
1. converting to/from functional tensors
2. detecting Tensor mutations - both metadata and Tensor value
3. regenerating/replaying views from their base
4. checking if a graph is functional i.e. whether it contains any mutation ops
"""
from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor
from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import definitely_true, sym_eq
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    transform_subclass,
)


aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")


def to_fun(t):
    if isinstance(t, Tensor):
        if is_traceable_wrapper_subclass(t):
            # See Note [Functionalization always runs last]
            # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
            # goes at the bottom.
            # Recurse here, so we can support nested wrapper subclasses.
            out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t))
            torch._mirror_autograd_meta_to(t, out)  # type: ignore[attr-defined]
            return out
        else:
            return FunctionalTensor.to_functional(t)
    else:
        return t


def sync_functional_tensor(t):
    if is_traceable_wrapper_subclass(t):
        attrs, ctx = t.__tensor_flatten__()  # type: ignore[attr-defined]
        for attr in attrs:
            sync_functional_tensor(getattr(t, attr))
    else:
        torch._sync(t)


# When subclasses are involved, t here will usually look something like:
# SubclassA(SubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor))))
def from_fun(t):
    if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
        # See Note [Functionalization always runs last]
        # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
        # goes at the bottom.
        # Recurse here, so we can support nested wrapper subclasses.
        out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t))
        torch._mirror_autograd_meta_to(t, out)  # type: ignore[attr-defined]
        return out

    if not isinstance(t, FunctionalTensor):
        # quick sanity assert
        if isinstance(t, torch.Tensor):
            assert not torch._is_functional_tensor(t)  # type: ignore[attr-defined]
        return t
    sync_functional_tensor(t)
    return torch._from_functional_tensor(t.elem)
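
# Example (illustrative sketch, not exercised anywhere in this file): round-tripping
# a plain tensor through the functional wrapper. `to_fun` expects an active
# FunctionalTensorMode, which AOTAutograd normally sets up before wrapping inputs:
#
#     from torch._subclasses.functional_tensor import FunctionalTensorMode
#     with FunctionalTensorMode():
#         f_x = to_fun(torch.randn(3))   # FunctionalTensor wrapping the original tensor
#         x_again = from_fun(f_x)        # unwrapped back to a plain tensor
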
def is_fun(t):
    if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
        # See Note [Functionalization always runs last]
        # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
        # goes at the bottom.
        # Recurse here, so we can support nested wrapper subclasses.
        t_attrs, _ = t.__tensor_flatten__()  # type: ignore[attr-defined]
        t_inners = [getattr(t, attr) for attr in t_attrs]
        any_fun = any(is_fun(x) for x in t_inners)
        all_fun = all(is_fun(x) for x in t_inners)
        assert any_fun == all_fun
        return any_fun

    return isinstance(t, FunctionalTensor)


# t here is either
# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
# (2) A traceable tensor subclass that holds a FunctionalTensor
# (3) Not a tensor
def has_data_mutation(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        # A tensor subclass was updated if any of its inner elements were updated
        return any(has_data_mutation(getattr(t, attr)) for attr in attrs)
    else:
        if isinstance(t, torch.Tensor):
            assert isinstance(t, FunctionalTensor)
            return torch._functionalize_has_data_mutation(t.elem)  # type: ignore[attr-defined]
        return False


def are_all_mutations_hidden_from_autograd(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        # If every mutation on the inner elements is hidden from autograd,
        # then the subclass mutation as a whole is hidden from autograd.
        return all(
            are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs
        )
    elif isinstance(t, torch.Tensor):
        assert isinstance(t, FunctionalTensor)
        return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem)
    else:
        return False


def are_all_mutations_under_no_grad_or_inference_mode(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        return all(
            are_all_mutations_under_no_grad_or_inference_mode(getattr(t, attr))
            for attr in attrs
        )
    else:
        assert isinstance(t, FunctionalTensor)
        return torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode(
            t.elem
        )


def was_inductor_storage_resized(t):
    if is_traceable_wrapper_subclass(t):
        attrs, _ = t.__tensor_flatten__()
        if any(was_inductor_storage_resized(getattr(t, attr)) for attr in attrs):
            raise RuntimeError(
                f"storage resizing is not supported on tensor subclass: {type(t)}"
            )
        # None of the inner tensors were resized.
        return False
    elif not isinstance(t, torch.Tensor):
        return False
    else:
        assert isinstance(t, FunctionalTensor)
        return torch._functionalize_was_inductor_storage_resized(t.elem)
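
# Example (illustrative sketch) of the mutation queries above: under an active
# FunctionalTensorMode, an in-place op on a wrapped tensor is recorded by the
# FunctionalTensorWrapper rather than applied to the original tensor, and the
# bindings used above read that record back out:
#
#     from torch._subclasses.functional_tensor import FunctionalTensorMode
#     with FunctionalTensorMode():
#         f_x = to_fun(torch.randn(3))
#         f_x.mul_(2)                               # recorded as a data mutation
#         assert has_data_mutation(f_x)
#         assert not was_inductor_storage_resized(f_x)
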
# f_arg here is either
# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
# (2) A traceable tensor subclass that holds a FunctionalTensor
# (3) Not a tensor
# Assumption: arg promises to be the "original" tensor wrapped by f_arg
# Note: "storage mutations" coming from set_() are a type of metadata mutation. So:
# - check_only_storage_mutation=True: only return true if there was a storage mutation
# - check_only_storage_mutation=False: return true if there was any metadata mutation (including a storage mutation)
def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool):
    if is_traceable_wrapper_subclass(f_arg):
        attrs, _ = f_arg.__tensor_flatten__()
        # A tensor subclass was updated if any of its inner elements were updated
        f_inner_ts = [getattr(f_arg, attr) for attr in attrs]
        inner_ts = [getattr(arg, attr) for attr in attrs]
        return any(
            has_metadata_mutation(
                f_inner_t,
                inner_t,
                check_only_storage_mutation=check_only_storage_mutation,
            )
            for f_inner_t, inner_t in zip(f_inner_ts, inner_ts)
        )
    else:
        if not isinstance(f_arg, torch.Tensor):
            assert not isinstance(arg, torch.Tensor)
            return False
        assert isinstance(f_arg, FunctionalTensor)
        assert isinstance(arg, FakeTensor)

        arg_after = torch._from_functional_tensor(f_arg.elem)
        # This is true if the current tensor experienced at least one set_() call
        maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem)  # type: ignore[attr-defined]
        # However, multiple set_() calls can cancel out. So we also check whether the
        # storage of the tensor has changed.
        # Note: if an input experienced two set_() calls that cancel out, **and**
        # it also experienced a data mutation, we pessimistically assume that the set_()
        # call is necessary here. We could in theory fix this, but this will
        # hopefully never happen in user code, and is not needed for fsdp.
        if is_sparse_any(arg):
            # TODO: add sparse tensor support to functionalization
            same_storages = False
        else:
            same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(
                arg_after.untyped_storage()
            )
        has_storage_metadata_mutation = maybe_storage_changed and not same_storages
        if check_only_storage_mutation:
            return has_storage_metadata_mutation

        # A storage metadata mutation is a type of metadata mutation, so return true if we saw one
        if has_storage_metadata_mutation:
            return True

        maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem)  # type: ignore[attr-defined]
        # This is true if the current tensor experienced at least one metadata mutation.
        # So if false, we know there was no metadata mutation
        if not maybe_metadata_mutated:
            return False

        # However, multiple metadata mutations can cancel out.
        # So we also check if the concrete sizes/strides on the tensor have changed.
        same_sizes = arg.shape == arg_after.shape
        same_strides = arg.stride() == arg_after.stride()
        same_offsets = arg.storage_offset() == arg_after.storage_offset()
        has_metadata_mutation_ = maybe_metadata_mutated and not (
            same_sizes and same_strides and same_offsets
        )
        # We only consider the tensor to have been metadata-mutated if its
        # sizes/strides/storage_offset actually ended up different
        # (storage mutations from set_() were already handled above).
        return has_metadata_mutation_
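
# Example (illustrative sketch): distinguishing a metadata-only mutation from a
# storage swap. AOTAutograd keeps both the original FakeTensor (`arg`) and its
# functional wrapper (`f_arg`) around, which is what this sketch mimics:
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     from torch._subclasses.functional_tensor import FunctionalTensorMode
#     arg = FakeTensorMode().from_tensor(torch.randn(2, 3))
#     with FunctionalTensorMode():
#         f_arg = to_fun(arg)
#         f_arg.t_()   # metadata mutation: sizes/strides change, storage does not
#         assert has_metadata_mutation(f_arg, arg, check_only_storage_mutation=False)
#         assert not has_metadata_mutation(f_arg, arg, check_only_storage_mutation=True)
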
def gen_alias_from_base(
    aliased_base_tensor,
    target_meta_tensor,
    target_requires_grad,
    target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None,
    *,
    replay_views,
):
    # Patch the correct requires_grad field of the output tensor, depending on whether:
    # (i) the reconstructed output (out) came from a tensor that requires grad or not;
    # and (ii) whether the concrete returned output requires grad or not.
    def patch_requires_grad(out):
        if aliased_base_tensor.requires_grad and not target_requires_grad:
            out = out.detach()
        elif not aliased_base_tensor.requires_grad and target_requires_grad:
            out.requires_grad_(True)
        return out

    # If provided, use the target functional tensor for replaying the views.
    #
    # In summary, we use the fact that FunctionalTensorWrapper saves the view
    # functions applied to itself (collected during functionalization) so as
    # to replay them on the aliased_base_tensor.
    if (
        replay_views
        and target_functional_tensor is not None
        and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
    ):
        functional_tensor = target_functional_tensor.tensor

        out = torch._functionalize_apply_view_metas(
            functional_tensor, aliased_base_tensor
        )
        # If re-applying the ViewMeta sequence succeeded, there should be no more
        # problems going forward. We just check we got to the target shape and
        # patch the requires_grad flag.
        assert out.shape == target_meta_tensor.shape, (
            "incorrect out shape after application of ViewMeta sequence: "
            f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
        )
        return patch_requires_grad(out)

    # Try to do view-replay if possible.
    # Fall back to .as_strided() if we can't.
    if target_meta_tensor._base is not None:
        # The base that we want to replay our view off of might have a different shape than the view's original base.
        b = target_meta_tensor._base
        abt = aliased_base_tensor
        # Don't unnecessarily call as_strided if nothing changed; as_strided's
        # backward is poorly implemented and slow
        if abt is not b and (
            abt.size() != b.size()
            or abt.stride() != b.stride()
            or abt.storage_offset() != b.storage_offset()
        ):
            reshaped_base_tensor = aliased_base_tensor.as_strided(
                b.size(), b.stride(), b.storage_offset()
            )
        else:
            reshaped_base_tensor = aliased_base_tensor
        out = target_meta_tensor._view_func(reshaped_base_tensor)
        # This shape mismatch can happen due to a bug in inplace/view handling in autograd.
        # Try putting a breakpoint here and running
        # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
        # Also, https://github.com/pytorch/pytorch/issues/49825
        #
        # As a stopgap, we'll fall back to as_strided.
        if out is not None and out.shape == target_meta_tensor.shape:
            return patch_requires_grad(out)

    size = target_meta_tensor.size()
    stride = target_meta_tensor.stride()
    storage_offset = target_meta_tensor.storage_offset()
    if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex():
        aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex():
        aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    else:
        aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
    # For outputs aliasing inputs, we need to check if the requires-gradness has changed.
    aliased_out = patch_requires_grad(aliased_out)
    # For outputs aliasing inputs, we need to check if the dtype has changed.
    # as_strided() is the "most generic" view, but it does not cover cross-dtype views
    if aliased_out.dtype != target_meta_tensor.dtype:
        aliased_out = aliased_out.view(target_meta_tensor.dtype)
    return aliased_out
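
# Example (illustrative sketch): regenerating an output alias from a base without
# functional view-replay. The target view is described by a meta tensor; the helper
# replays it (or reconstructs it via as_strided) on top of the real base:
#
#     base = torch.randn(4, 4)
#     target_meta = torch.empty(4, 4, device="meta")[1]   # a view with shape (4,)
#     out = gen_alias_from_base(
#         base, target_meta, target_requires_grad=False, replay_views=False
#     )
#     # `out` aliases `base` and matches target_meta's size/stride/storage_offset.
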
def has_same_metadata(t1, t2):
    return (
        definitely_true(sym_eq(t1.size(), t2.size()))
        and definitely_true(t1.layout == t2.layout)
        and (
            is_sparse_any(t1)
            or (
                definitely_true(sym_eq(t1.stride(), t2.stride()))
                and definitely_true(t1.storage_offset() == t2.storage_offset())
            )
        )
        and t1.is_conj() == t2.is_conj()
        and t1.is_neg() == t2.is_neg()
    )


# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
# after applying all the ViewMeta operations.
class FunctionalTensorMetadataEq:
    def __init__(self, tensor: torch.Tensor) -> None:
        assert torch._is_functional_tensor(tensor)
        self.tensor = tensor

    def __eq__(self, other: object) -> bool:
        # If other is None, then it probably means that we weren't able to recreate
        # the FunctionalTensorMetadataEq. One of these cases is when we update the
        # view metadata by calling: create_synthetic_base_metadata.
        if other is None:
            return True

        # Comparison against any other type is not implemented.
        if not isinstance(other, FunctionalTensorMetadataEq):
            return NotImplemented

        return has_same_metadata(self.tensor, other.tensor)


# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed.
#
# Normally it would be enough just to check `arg is new_arg`, which suffices for
# functionalization to confirm that inputs were not mutated when running the user's model with functionalization on.
# But when we have subclass inputs, we can't rely on that:
# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs
# a brand new subclass instance: we are calling __tensor_unflatten__, and going
# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor))
def was_tensor_updated(arg, new_arg):
    if is_traceable_wrapper_subclass(arg):
        assert is_traceable_wrapper_subclass(new_arg)
        attrs, _ = arg.__tensor_flatten__()
        new_attrs, _ = new_arg.__tensor_flatten__()
        assert attrs == new_attrs
        # A tensor subclass was updated if any of its inner elements were updated
        return any(
            was_tensor_updated(getattr(arg, attr), getattr(new_arg, attr))
            for attr in attrs
        )
    else:
        return arg is not new_arg
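
# Example (illustrative): for plain tensors, was_tensor_updated is a pure object
# identity check, so any rewrapping (e.g. a fresh view) counts as "updated":
#
#     x = torch.randn(3)
#     assert not was_tensor_updated(x, x)
#     assert was_tensor_updated(x, x.view(3))   # a different Python object
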
# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether or not an input's metadata was mutated by checking to see if
# the inner tensor has changed, but still shares storage with the old input
def was_tensor_metadata_updated(arg, new_arg):
    if is_traceable_wrapper_subclass(arg):
        assert is_traceable_wrapper_subclass(new_arg)
        attrs, _ = arg.__tensor_flatten__()
        new_attrs, _ = new_arg.__tensor_flatten__()
        assert attrs == new_attrs
        # A tensor subclass was updated if any of its inner elements were updated
        return any(
            was_tensor_metadata_updated(getattr(arg, attr), getattr(new_arg, attr))
            for attr in attrs
        )
    else:
        return arg is not new_arg and StorageWeakRef(
            arg.untyped_storage()
        ) == StorageWeakRef(new_arg.untyped_storage())


# Returns the number of detected copy_/set_ input mutations
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
    allowed_mutation_ops = [
        torch.ops.aten.copy_.default,
        torch.ops.aten.set_.source_Tensor,
    ]
    if hasattr(torch.ops.fsdp, "set_"):
        allowed_mutation_ops.append(torch.ops.fsdp.set_.default)

    placeholders = set()
    mutation_count = 0
    # NB: It would also be nice to verify that the mutations all happen at the
    # end, but we also do some administrative views after mutations so this
    # isn't actually true. (TODO: Could this cause problems for Inductor?)
    for n in fx_g.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
        if isinstance(n.target, torch._ops.OpOverload):
            if n.target in allowed_mutation_ops:
                # Can only copy_/set_ into an input.
                # This is mostly a hack to avoid failing XLA tests.
                # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
                if "set_buffer_donor_" not in str(n.args[0]):
                    assert (
                        n.args[0] in placeholders
                    ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
                mutation_count += 1
            else:
                assert (
                    not n.target._schema.is_mutable
                ), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
    return mutation_count
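
# Example (illustrative sketch): a graph traced with functionalization enabled should
# pass assert_functional_graph. Below, the in-place add_ on a clone is rewritten to an
# out-of-place op, so no mutations (and no copy_ write-backs into inputs) remain:
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#     from torch.func import functionalize
#
#     def f(x):
#         y = x.clone()
#         y.add_(1)
#         return y
#
#     gm = make_fx(functionalize(f))(torch.randn(3))
#     assert assert_functional_graph(gm.graph) == 0   # no copy_ back into inputs
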
def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None:
    placeholders = set()
    for n in fx_g.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
        if isinstance(n.target, torch._ops.OpOverload):
            if n.target is torch.ops.aten.copy_.default:
                # Can only copy_ into an input, and can only do so once
                if "set_buffer_donor_" not in str(n.args[0]):
                    assert (
                        n.args[0] in placeholders
                    ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
                    placeholders.remove(n.args[0])
                copy_from_node = n.args[1]
                # Pre-condition: every node other than these copy_() nodes has a
                # "stack_trace" field in its meta (the copy_() nodes were added manually
                # during functionalization), so we propagate the stack trace from the
                # node we are copying from.
                if "stack_trace" in copy_from_node.meta:
                    n.meta["stack_trace"] = copy_from_node.meta["stack_trace"]


def _check_if_mutation_can_be_in_graph(
    keep_input_mutations: bool,
    mutates_data,
    mutates_metadata,
    mutations_hidden_from_autograd,
    mutations_under_no_grad_or_inference_mode,
    mutates_storage_metadata,
    mutation_inductor_storage_resize,
    requires_grad,
):
    if keep_input_mutations:
        in_graph = (
            mutates_data or mutates_storage_metadata or mutation_inductor_storage_resize
        ) and (
            (not mutates_metadata and not requires_grad)
            or mutations_hidden_from_autograd
            or mutations_under_no_grad_or_inference_mode
        )
    else:
        in_graph = False
    # See Note [set_() Input Mutations in AOTAutograd]
    # If there was a `set_()`, we require that all mutations were under no_grad,
    # so we can (safely) emit the set_() in the graph at runtime.
    # resize_() gets the same treatment.
    if mutation_inductor_storage_resize or mutates_storage_metadata:
        op_name = "resize_" if mutation_inductor_storage_resize else "set_"
        assert in_graph, f"""\
Encountered a {op_name} on a graph input, but the input has other mutations that we cannot
keep in the graph. This is not supported today. Current state:
  keep_input_mutations={keep_input_mutations}
  mutates_data={mutates_data}
  mutates_metadata={mutates_metadata}
  mutations_hidden_from_autograd={mutations_hidden_from_autograd}
  mutations_under_no_grad_or_inference_mode={mutations_under_no_grad_or_inference_mode}
  mutation_inductor_storage_resize={mutation_inductor_storage_resize}
  requires_grad={requires_grad}"""
    return in_graph
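
# Example (illustrative): a plain data mutation on an input can be kept in the graph
# when the input does not require grad (or when the mutation is hidden from autograd,
# or ran under no_grad/inference_mode); flipping requires_grad to True makes this False:
#
#     _check_if_mutation_can_be_in_graph(
#         keep_input_mutations=True,
#         mutates_data=True,
#         mutates_metadata=False,
#         mutations_hidden_from_autograd=False,
#         mutations_under_no_grad_or_inference_mode=False,
#         mutates_storage_metadata=False,
#         mutation_inductor_storage_resize=False,
#         requires_grad=False,
#     )  # -> True
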