# mypy: allow-untyped-defs
"""
This module is one of the analysis modules - it takes as input a function or graph
and some preexisting properties, and returns some data that is useful for deciding
how to further proceed with compilation or construct runtime wrappers.

In particular, the following analyses are provided:
1. Refine the view and mutation metadata collected previously - removing duplicate
   inputs or mapping views to their bases.
2. Analyze the function signature for export graphs.
"""

import itertools
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int

from .. import config
from .collect_metadata_analysis import coerce_tangent
from .schemas import (
    BackwardSignature,
    GraphSignature,
    InputAliasInfo,
    OutputAliasInfo,
    OutputType,
    ViewAndMutationMeta,
)
from .utils import strict_zip


zip = strict_zip


def remove_dupe_metadata(
    m: ViewAndMutationMeta,
    keep_arg_mask: List[bool],
    add_dupe_map: List[int],
) -> ViewAndMutationMeta:
    """Build a copy of ``m`` for the calling convention with duplicate inputs removed.

    Args:
        m: metadata for the calling convention that still contains the dupes.
        keep_arg_mask: one flag per entry of ``m.input_info``; ``False`` marks an
            input that is dropped as a duplicate. The first input is always kept.
        add_dupe_map: indexed by an input position in ``m``; gives the position to
            use for an output's ``base_idx`` in the de-duplicated convention.

    Returns:
        A new ``ViewAndMutationMeta`` with filtered ``input_info``, remapped
        ``output_info`` base indices and filtered ``traced_tangents``. The subclass
        metadata lists are left empty.
    """
    assert len(m.input_info) == len(keep_arg_mask)
    # Easy invariant: the first argument should never be a dupe (it will be kept)
    assert len(keep_arg_mask) > 0 and keep_arg_mask[0]

    # The first `n_data_mutations` traced tangents are treated as belonging to
    # mutated inputs; everything after them is carried over untouched.
    n_data_mutations = sum(1 for info in m.input_info if info.mutates_data)
    trailing_tangents = m.traced_tangents[n_data_mutations:]
    leading_tangents = m.traced_tangents[:n_data_mutations]

    # Keep a mutated input's tangent only if that input survives de-duplication.
    # NOTE(review): position `pos` among the leading tangents is mapped to an input
    # via mutated_inp_runtime_indices[pos]; this assumes both lists enumerate the
    # same mutated inputs in the same order — TODO confirm against the collector.
    # See Note [Tangents must be contiguous]
    kept_leading_tangents = []
    for pos, tangent in enumerate(leading_tangents):
        if keep_arg_mask[m.mutated_inp_runtime_indices[pos]]:
            kept_leading_tangents.append(tangent)
    new_traced_tangents = kept_leading_tangents + trailing_tangents

    kept_input_info = []
    for inp_pos, info in enumerate(m.input_info):
        if keep_arg_mask[inp_pos]:
            kept_input_info.append(info)

    # For outputs that are views of inputs, we store the index of the input that the
    # output was generated from; that index must account for the removed dupes.
    remapped_output_info = []
    for out in m.output_info:
        remapped_output_info.append(
            OutputAliasInfo(
                output_type=out.output_type,
                raw_type=out.raw_type,
                dynamic_dims=out.dynamic_dims,
                base_idx=(
                    None if out.base_idx is None else add_dupe_map[out.base_idx]
                ),
                requires_grad=out.requires_grad,
                functional_tensor=out.functional_tensor,
            )
        )

    return ViewAndMutationMeta(
        input_info=kept_input_info,
        output_info=remapped_output_info,
        num_intermediate_bases=m.num_intermediate_bases,
        keep_input_mutations=m.keep_input_mutations,
        traced_tangents=new_traced_tangents,
        # We are guaranteed not to get here, since dupes are not supported today with subclass inputs.
        subclass_inp_meta=[],
        subclass_fw_graph_out_meta=[],
        subclass_tangent_meta=[],
        is_train=m.is_train,
    )


# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
# after adding synthetic base arguments to the function.
# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
# and updating it with our synthetic base calling convention.
#
# When config.debug_assert is set, we automatically regenerate the metadata
# and compare it to this output for sanity.
104# 105# In addition to the updated metadata, also return the list of input indices 106# that will need to be updated in the synthetic base epilogue 107def create_synthetic_base_metadata( 108 m: ViewAndMutationMeta, 109 # Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a 110 # synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata) 111 synthetic_base_info: List[Union[int, Tuple[int, torch.Tensor]]], 112 outer_args: List[Any], 113 inner_args: List[Any], 114) -> Tuple[ViewAndMutationMeta, List[int]]: 115 # maps inner arg indices to outer arg indices 116 synthetic_base_to_indices: Dict[int, List[int]] = {} 117 for inner_idx in range(len(inner_args)): 118 outer_aliased_indices_of_current_base_arg = [ 119 outer_idx 120 for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info) 121 if (isinstance(inner_idx_or_tuple, int) and inner_idx_or_tuple == inner_idx) 122 or ( 123 isinstance(inner_idx_or_tuple, tuple) 124 and inner_idx_or_tuple[0] == inner_idx 125 ) 126 ] 127 synthetic_base_to_indices[inner_idx] = outer_aliased_indices_of_current_base_arg 128 129 # given the requires_grad info on mutated inputs, 130 # generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases. 131 input_infos = [] 132 for outer_indices in synthetic_base_to_indices.values(): 133 # leaf-ness should be all-or-nothing for aliased tensor. 
134 # (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf) 135 any_leaf = any(m.input_info[x].is_leaf for x in outer_indices) 136 all_leaf = all(m.input_info[x].is_leaf for x in outer_indices) 137 assert any_leaf == all_leaf 138 139 mutates_data = ( 140 True 141 if len(outer_indices) > 1 142 else m.input_info[outer_indices[0]].mutates_data 143 ) 144 mutates_metadata = ( 145 False 146 if len(outer_indices) > 1 147 else m.input_info[outer_indices[0]].mutates_metadata 148 ) 149 requires_grad = any(m.input_info[x].requires_grad for x in outer_indices) 150 mutations_hidden_from_autograd = all( 151 m.input_info[x].mutations_hidden_from_autograd for x in outer_indices 152 ) 153 mutations_under_no_grad_or_inference_mode = all( 154 m.input_info[x].mutations_under_no_grad_or_inference_mode 155 for x in outer_indices 156 ) 157 158 mutation_inductor_storage_resize = all( 159 m.input_info[x].mutation_inductor_storage_resize for x in outer_indices 160 ) 161 162 inpt_info = InputAliasInfo( 163 # If len(outer_indices) > 1, then this input is a synthetic base. 164 # The invariant is that to the rest of aot autograd, synthetic bases only show up if 165 # one of their aliases gets a data mutation. And if any of their aliases get metadata 166 # mutations, they will be hidden from the rest of aot autograd. 
167 mutates_data=mutates_data, 168 mutates_metadata=mutates_metadata, 169 mutations_hidden_from_autograd=all( 170 m.input_info[x].mutations_hidden_from_autograd for x in outer_indices 171 ), 172 mutates_storage_metadata=False 173 if len(outer_indices) > 1 174 else m.input_info[outer_indices[0]].mutates_storage_metadata, 175 mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, 176 mutation_inductor_storage_resize=mutation_inductor_storage_resize, 177 is_leaf=any_leaf, 178 requires_grad=requires_grad, 179 keep_input_mutations=m.keep_input_mutations, 180 ) 181 input_infos.append(inpt_info) 182 183 # Find any inputs that fulfill the following criteria: 184 # (1) They are part of a synthetic base (because they alias another input, 185 # and at least one input experiences a data mutation) 186 # (2) They experience a metadata mutation 187 outer_aliased_arg_idx_with_metadata_mutations = [ 188 outer_idx 189 for outer_idx, inpt_info in enumerate(m.input_info) 190 if inpt_info.mutates_metadata 191 and not isinstance(synthetic_base_info[outer_idx], int) 192 ] 193 194 # grab the original requires grad info on the outputs, except the ones from the mutated inputs 195 input_metadata_output_info = [ 196 OutputAliasInfo( 197 output_type=OutputType.alias_of_input, 198 raw_type=FunctionalTensor, 199 dynamic_dims={ 200 i 201 for i, s in enumerate(outer_args[outer_idx].shape) 202 if not is_concrete_int(s) 203 }, 204 base_idx=synthetic_base_info[outer_idx][0], # type: ignore[index] 205 requires_grad=outer_args[outer_idx].requires_grad, 206 ) 207 for outer_idx in outer_aliased_arg_idx_with_metadata_mutations 208 ] 209 existing_output_infos = [] 210 for o in m.output_info: 211 new_base_idx = ( 212 None 213 if o.base_idx is None 214 else ( 215 synthetic_base_info[o.base_idx] 216 if isinstance(synthetic_base_info[o.base_idx], int) 217 else synthetic_base_info[o.base_idx][0] # type: ignore[index] 218 ) 219 ) 220 # If base_idx is changed for OutputType.is_input, 
we need to update the output type to reflect the change 221 new_output_type = ( 222 OutputType.alias_of_input 223 if o.output_type == OutputType.is_input and o.base_idx != new_base_idx 224 else o.output_type 225 ) 226 existing_output_infos.append( 227 OutputAliasInfo( 228 output_type=new_output_type, 229 raw_type=o.raw_type, 230 dynamic_dims=o.dynamic_dims, 231 # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases 232 base_idx=new_base_idx, # type: ignore[arg-type] 233 requires_grad=o.requires_grad, 234 functional_tensor=o.functional_tensor, 235 ) 236 ) 237 238 inner_mutated_tangents = [ 239 # See Note [Tangents must be contiguous] 240 coerce_tangent(x) 241 for inner_idx, x in enumerate(inner_args) 242 if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad 243 ] 244 245 output_info = existing_output_infos + input_metadata_output_info 246 # Regenerate traced tangents to include mutated inputs including synthetic bases 247 traced_tangents = ( 248 inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :] 249 ) 250 251 return ( 252 ViewAndMutationMeta( 253 input_info=input_infos, 254 output_info=output_info, 255 num_intermediate_bases=m.num_intermediate_bases, 256 keep_input_mutations=m.keep_input_mutations, 257 traced_tangents=traced_tangents, 258 # We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs. 
259 subclass_inp_meta=[], 260 subclass_fw_graph_out_meta=[], 261 subclass_tangent_meta=[], 262 is_train=m.is_train, 263 ), 264 outer_aliased_arg_idx_with_metadata_mutations, 265 ) 266 267 268def _get_last_mem_address(x): 269 out = x.storage_offset() 270 for size, stride in zip(x.size(), x.stride()): 271 out += (size - 1) * stride 272 return out 273 274 275# Assumption: x and y are known to share a storage, and we are trying to determine 276# if their memory is actually completely disjoint, based on sizes/strides/storage_offset 277def _tensors_definitely_do_not_overlap(x, y): 278 if x is y: 279 return False 280 if x.numel() == 0 or y.numel() == 0: 281 return True 282 283 # Make x always on the left 284 if x.storage_offset() > y.storage_offset(): 285 x, y = y, x 286 # Short-circuit in the "obvious" overlapping case: both tensors are contiguous 287 if x.is_contiguous() and y.is_contiguous(): 288 if x.storage_offset() + x.numel() > y.storage_offset(): 289 # definitely overlap 290 return False 291 else: 292 # definitely no overlap 293 return True 294 295 # Short-circuit: if last memory address of x is < start of y, then not overlapping. 296 x_last = _get_last_mem_address(x) 297 if x_last < y.storage_offset(): 298 return True 299 300 if x.dim() == 2 and y.dim() == 2 and x.stride(1) == 1 and y.stride(1) == 1: 301 # This cases is needed for the shampoo optimizer. 
302 # All tensors are 2d (non-contiguous), have the same outer stride, and have an inner stride of 1 303 # (so rows are contiguous) 304 if x.stride(0) == y.stride(0): 305 offset_delta = y.storage_offset() - x.storage_offset() 306 if offset_delta < x.size(1): 307 # definitely overlaps (row 0 of y overlaps with row 0 of x) 308 # Example: 309 # base = torch.arange(32).reshape(4, 8) 310 # x = base.narrow(1, 0, 4) 311 # x: size=(4, 4), stride=(8, 1), offset=0 312 # y = base.narrow(1, 3, 4) 313 # y: size=(4, 4), stride=(8, 1), offset=3 314 return False 315 x_total_elems_covered = x.stride(0) * (x.size(0) - 1) + x.size(1) 316 if x_total_elems_covered <= offset_delta: 317 # definitely does not overlap (last byte of x is before start of y) 318 # Example: 319 # x: size=(4, 4), stride=(8, 1), offset=0 (last byte is 27) 320 # y: size=(4, 4), stride=(8, 1), offset=28 (start byte is 28) 321 return True 322 # At this point, we want to check if the 0th row of y 323 # overlaps with **some** row of x. 324 # We can check this by shifting y backward by the shared stride, repeatedly, 325 # until the first row of y is before the first row of x. 326 # Then we can check if these rows overlap. 327 # We can accomplish this by modding our offset by the stride. 
328 offset_delta_mod = offset_delta % x.stride(0) 329 # Example: 330 # 0 1 2 3 331 # 9 10 11 12 332 # 18 19 20 21 333 # 27 28 29 30 334 # x: size=(4, 4), stride=(9, 1), offset=0 335 # y: size=(4, 4), stride=(9, 1), offset=22 (this would not overlap) 336 # y: size=(4, 4), stride=(9, 1), offset=23 (this would not overlap) 337 # y: size=(4, 4), stride=(9, 1), offset=24 (this would overlap) 338 # y: size=(4, 4), stride=(9, 1), offset=25 (this would overlap) 339 # If the interval [modded_offset, modded_offset + x_size] falls entirely 340 # without 341 if offset_delta_mod + y.size(1) <= x.stride(0): 342 return True 343 return False 344 345 346def compute_overlapping_inputs(fwd_inputs, aliased_input_indices): 347 max_aliased_inps_w_dyn_shapes = ( 348 config._max_aliased_inputs_with_dynamic_shapes_enabled 349 ) 350 definitely_error_on_dyn_shapes = False 351 # If the JK is false / not set, we will fall back to obeying the config above 352 # If it is true, we will always error when there are aliased + mutated inps with dynamic shapes 353 if torch._inductor.config.is_fbcode(): 354 definitely_error_on_dyn_shapes = torch._utils_internal.justknobs_check( 355 "pytorch/dynamo:disable_aliased_inputs_with_mutation_and_dyn_shapes" 356 ) 357 358 actual_aliased_indices = set() 359 num_aliases = len(aliased_input_indices) 360 # > 2 check because num_aliases==1 means no aliasing 361 if num_aliases >= 2 and ( 362 definitely_error_on_dyn_shapes or num_aliases > max_aliased_inps_w_dyn_shapes 363 ): 364 dynamic_shape_indices = set() 365 for j in range(num_aliases): 366 j_ = aliased_input_indices[j] 367 curr_inp = fwd_inputs[j_] 368 if any( 369 isinstance(x, torch.SymInt) 370 for x in itertools.chain( 371 curr_inp.shape, curr_inp.stride(), [curr_inp.storage_offset()] 372 ) 373 ): 374 dynamic_shape_indices.add(j_) 375 assert ( 376 len(dynamic_shape_indices) == 0 377 ), f"""\ 378Encountered a graph where: 379- {num_aliases} graph inputs all share the same storage (input indices: 
{str(aliased_input_indices)}) 380- at least one of these aliased inputs was mutated 381- at least one of these inputs is being compiled with dynamic shapes (indices: {str(dynamic_shape_indices)}) 382 383Current limit: {str(max_aliased_inps_w_dyn_shapes)} 384Killswitch enabled: {str(definitely_error_on_dyn_shapes)} 385 386The most common way to run into this situation is when your model parameters are allocated as one giant buffer 387and are all mutated by the optimizer, and some of your parameters end up getting compiled with dynamic shapes. 388 389You can avoid this problem by marking your parameters so they explicitly do not participate in dynamic shapes, 390by marking each dim of your parameter static: 391 392torch._dynamo.mark_static(param, 0) # (1, 2, ... for every dimension on the parameter). 393 394If you are running into this issue in a situation where your parameters are static but some other inputs 395are aliased and mutated, and they should be dynamic, please file an issue. 396""" 397 for j in range(num_aliases): 398 for i in range(j): 399 j_ = aliased_input_indices[j] 400 i_ = aliased_input_indices[i] 401 if not _tensors_definitely_do_not_overlap(fwd_inputs[i_], fwd_inputs[j_]): 402 actual_aliased_indices.add(i_) 403 actual_aliased_indices.add(j_) 404 return actual_aliased_indices 405 406 407def _graph_input_names(gm): 408 return [node.name for node in gm.graph.find_nodes(op="placeholder")] 409 410 411def _graph_output_names(gm): 412 output_node = next(iter(reversed(gm.graph.nodes))) 413 assert output_node.op == "output" and len(output_node.args) == 1 414 return_args = output_node.args[0] 415 return [getattr(return_arg, "name", None) for return_arg in return_args] 416 417 418def create_graph_signature( 419 fx_g: torch.fx.GraphModule, 420 fw_metadata: ViewAndMutationMeta, 421 in_spec: pytree.TreeSpec, 422 out_spec: pytree.TreeSpec, 423 *, 424 user_args_flat: List[Tensor], 425 params_and_buffers_flat: List[Tensor], 426 param_names: List[str], 427 
buffer_names: List[str], 428 trace_joint: bool, 429 num_user_fw_outs: Optional[int], 430 loss_index: Optional[int], 431) -> GraphSignature: 432 # Retrieve graph input names 433 graph_input_names = _graph_input_names(fx_g) 434 # Retrieve graph output names 435 graph_output_names = _graph_output_names(fx_g) 436 437 num_params_buffers = len(param_names) + len(buffer_names) 438 num_tokens = len(fw_metadata.tokens) 439 # We have enough restrictions on the graph (no de-duping, synthetic bases, etc), 440 # Such that # graph inps = # user inps + # params + # buffers 441 num_user_args = len(graph_input_names) - num_params_buffers - num_tokens 442 443 if trace_joint: 444 assert num_user_fw_outs is not None 445 num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices 446 backward_output_names = graph_output_names[num_fw_outs:] 447 448 grad_index = itertools.count(0) 449 gradients_to_parameters = { 450 backward_output_names[next(grad_index)]: param_names[i] 451 for i, param in enumerate(params_and_buffers_flat) 452 if param.requires_grad 453 } 454 455 gradients_to_user_inputs = { 456 backward_output_names[next(grad_index)]: graph_input_names[ 457 i + len(params_and_buffers_flat) 458 ] 459 for i, user_input in enumerate(user_args_flat) 460 if user_input.requires_grad 461 } 462 463 assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len( 464 backward_output_names 465 ) 466 467 # Check that we have fully accounted for all graph outputs 468 backward_signature = BackwardSignature( 469 gradients_to_parameters, 470 gradients_to_user_inputs, 471 graph_output_names[loss_index], 472 ) 473 else: 474 backward_signature = None 475 num_user_fw_outs = ( 476 len(graph_output_names) 477 - fw_metadata.num_mutated_inp_runtime_indices 478 - num_tokens 479 ) 480 481 return GraphSignature.from_tracing_metadata( 482 in_spec=in_spec, 483 out_spec=out_spec, 484 graph_input_names=graph_input_names, 485 graph_output_names=graph_output_names, 486 
view_mutation_metadata=fw_metadata, 487 named_parameters=param_names, 488 named_buffers=buffer_names, 489 num_user_inputs=num_user_args, 490 num_user_outputs=num_user_fw_outs, 491 loss_index=loss_index, 492 backward_signature=backward_signature, 493 ) 494