# mypy: allow-untyped-defs
import contextlib
import functools
import itertools
import logging
import os
import sys
import time
import warnings
from itertools import count

from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from unittest import mock

import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools

import torch.fx
import torch.utils._pytree as pytree

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo import (
    compiled_autograd,
    config as dynamo_config,
    logging as dynamo_logging,
    utils as dynamo_utils,
)
from torch._dynamo.utils import (
    counters,
    detect_fake_mode,
    flatten_graph_inputs,
    lazy_format_graph_code,
)
from torch._functorch import config as functorch_config
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    get_placeholders,
    log_cudagraph_skip_and_bump_counter,
)

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    fresh_inductor_cache,
    should_assume_input_aligned,
    tensor_is_aligned,
)
from torch._logging import trace_structured
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import compile_time_strobelight_meta
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from .._dynamo.backends.common import aot_autograd
from ..fx._lazy_graph_module import _use_lazy_graph_module  # type: ignore[attr-defined]
from ..fx.graph import _PyTreeCodeGen
from . import config, metrics
from .debug import DebugContext
from .decomposition import select_decomp_table
from .fx_passes.joint_graph import joint_graph_passes
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import ExternKernelNode
from .utils import (
    get_cloned_parameter_buffer_name,
    has_incompatible_cudagraph_ops,
    maybe_get_suppress_shape_guards_ctx,
    output_node,
)
from .virtualized import V

if config.is_fbcode():
    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
else:
    # no-op decorator
    def time_and_log(attr: str):
        return dynamo_utils.identity


log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
ALIGNMENT = 16


# copy_ fails when trying to write to tensors with memory overlap;
# for expanded dimensions (a dimension which used to have size 1 -> ?)
# we can select one element from that dimension and write to it
# to achieve writing to all values of that dimension of the input tensor
def get_expanded_dims(t):
    if not isinstance(t, torch.Tensor):
        return None
    return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]


def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
    for expanded_dim in expanded_dims:
        t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
    return t


def complex_memory_overlap(t: torch.Tensor) -> bool:
    # if torch._debug_has_internal_overlap thinks this tensor potentially has
    # memory overlap internally, let's dig deeper to find out whether it's true.
    #
    # Call squeeze() so that dimension with size 1 does not cause false positive.
    t = index_expanded_dims(t, get_expanded_dims(t)).squeeze()
    if torch._debug_has_internal_overlap(t) != 0:
        strides = t.stride()
        sizes = t.shape
        indices = list(range(len(strides)))
        indices = [x for _, x in sorted(zip(strides, indices))]
        for i in range(len(strides)):
            prev_stride = 1 if i == 0 else strides[indices[i - 1]]
            prev_size = 1 if i == 0 else sizes[indices[i - 1]]
            if strides[indices[i]] < prev_stride * prev_size:
                return True
    return False
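

# Illustrative sketch (not part of the compile path): an expanded dimension shows
# up as a stride-0 dimension, which is what get_expanded_dims() detects and
# index_expanded_dims() collapses back to a single element. Roughly:
#
#     x = torch.zeros(3, 1).expand(3, 4)   # strides (1, 0)
#     get_expanded_dims(x)                 # -> [1]
#     index_expanded_dims(x, [1]).shape    # -> torch.Size([3, 1])
#
# so a copy_() into the sliced view writes every logical element of dim 1 exactly once.

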
def get_static_input_idxs(num_fixed):
    # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
    # of cudagraphs. Rather than copying these into cudagraph-owned memory
    # like we do for normal inputs on each run, we will re-record a cudagraph if these
    # parameter locations change.
    context = torch._guards.TracingContext.try_get()
    fixed = list(range(num_fixed))
    if not context or not context.fw_metadata:
        return fixed

    return fixed + context.fw_metadata.static_parameter_indices


@functools.lru_cache(None)
def _step_logger():
    return dynamo_logging.get_step_logger(log)


@functools.lru_cache(None)
def _warn_tf32_disabled():
    if (
        torch.cuda.is_available()
        and not torch.backends.cuda.matmul.allow_tf32
        and torch.cuda.get_device_capability() >= (8, 0)
    ):
        warnings.warn(
            "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
            "Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
        )


def _unlift_graph(mod, gm, graph_signature):
    from torch.export.unflatten import _assign_attr, _AttrKind

    state_dict = {}
    for name, param in mod.named_parameters(remove_duplicate=False):
        state_dict[name] = param
        _assign_attr(
            param,
            gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )
    for name, buffer in mod.named_buffers(remove_duplicate=False):
        state_dict[name] = buffer
        _assign_attr(
            buffer,
            gm,
            name,
            attr_kind=_AttrKind.BUFFER,
        )

    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
    lifted_inputs = []

    # In AOTI, module parameters and buffers are not lifted as graph inputs.
    # As a result, mutating a buffer has side effects which make its initial
    # value different from eager mode's, so we clone buffers here as a copy.
    # We are not cloning parameters, although that will be needed if we want to
    # support training.
    for node in placeholder_nodes:
        node_name = node.name
        if node_name in graph_signature.inputs_to_parameters:
            parameter_name = graph_signature.inputs_to_parameters[node_name]
            lifted_inputs.append(parameter_name)
        elif node_name in graph_signature.inputs_to_buffers:
            buffer_name = graph_signature.inputs_to_buffers[node_name]
            lifted_inputs.append(buffer_name)
            gm.meta[
                get_cloned_parameter_buffer_name(buffer_name)
            ] = clone_preserve_strides(state_dict[buffer_name])
        else:
            assert node_name in graph_signature.user_inputs
            lifted_inputs.append(None)

    from torch.export._unlift import _unlift

    outputs = list(gm.graph.nodes)[-1].args[0]
    mutated_outputs = []
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    output_tokens = graph_signature.output_tokens
    for idx, out in enumerate(outputs):
        value = None

        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            if out.name in buffer_mutations:
                value = buffer_mutations[out.name]
            elif out.name in user_input_mutations:
                value = user_input_mutations[out.name]

        mutated_outputs.append(value)

    unlifted_gm = _unlift(
        gm,
        lifted_inputs,
        mutated_outputs,
        pytree.LeafSpec(),
        None,
        state_dict,
        {},
    )
    return unlifted_gm


def _get_subgraph_names(gm):
    for node in sorted(
        itertools.chain(
            gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond),
            gm.graph.find_nodes(
                op="call_function", target=torch.ops.higher_order.while_loop
            ),
        )
    ):
        if node.target == torch.ops.higher_order.cond:
            true_subgraph_name = node.args[1].name
            false_subgraph_name = node.args[2].name
            yield true_subgraph_name
            yield false_subgraph_name
        elif node.target == torch.ops.higher_order.while_loop:
            cond_subgraph_name = node.args[0].name
            body_subgraph_name = node.args[1].name
            yield cond_subgraph_name
            yield body_subgraph_name


def _recursive_pre_grad_passes(gm, example_inputs):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        # as we don't have recursive example inputs, pass None here
        new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
        setattr(gm, subgraph_name, new_subgraph)
    return pre_grad_passes(gm, example_inputs)


def _recursive_joint_graph_passes(gm):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        _recursive_joint_graph_passes(subgraph)
    joint_graph_passes(gm)


def _recursive_post_grad_passes(gm, is_inference: bool = False):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        _recursive_post_grad_passes(subgraph, is_inference)
    post_grad_passes(gm, is_inference)


def split_const_gm(
    gm: torch.fx.GraphModule,
) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
    """
    This function takes a GraphModule input "gm".
    The gm is split into 2 components:
      1) const_gm, which consists of the subgraph of gm that can be constant folded.
      2) gm (modified in place), which returns the graph after constant folding.

    const_output_index is a mapping from the corresponding node name in gm to the
    output index of const_gm.
    Returns (const_gm, const_output_index).
    """
    from torch._inductor.constant_folding import (
        CONST_MODULE_TAG,
        META_TAG,
        MODULE_TAG,
        replace_node_with_constant,
        run_and_get_constant_graph,
    )

    const_gm = run_and_get_constant_graph(gm)
    const_result = const_gm()

    const_outputs = {
        x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
    }

    to_erase_node = []
    to_replace_node = []
    const_output_index = {}
    for node in gm.graph.nodes:
        if node.name in const_outputs:
            to_replace_node.append(node)
        elif node.meta[META_TAG] == CONST_MODULE_TAG:
            to_erase_node.append(node)

    for node in to_replace_node:
        new_const_name = "_FOLDED_CONST_" + node.name
        replace_node_with_constant(
            gm,
            node,
            const_result[const_outputs[node.name]],
            new_const_name,
        )
        const_output_index[new_const_name] = const_outputs[node.name]
    for node in to_erase_node[::-1]:
        if node.users:
            for n in node.users:
                assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
        else:
            gm.graph.erase_node(node)
    gm.recompile()

    return const_gm, const_output_index


def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
    aten = torch.ops.aten
    tf32_ops = {
        aten.mm.default,
        aten.addmm.default,
        aten.bmm.default,
        aten.baddbmm.default,
    }
    for target in tf32_ops:
        for node in gm.graph.find_nodes(op="call_function", target=target):
            if (
                isinstance(node.meta.get("val", None), torch.Tensor)
                and node.meta["val"].dtype == torch.float32
                and node.meta["val"].device.type == "cuda"
            ):
                return True
    return False
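

# Illustrative note: the TF32 hint emitted by _warn_tf32_disabled() (triggered when
# is_tf32_warning_applicable() finds fp32 matmuls on an Ampere-or-newer GPU) can be
# silenced by opting in before compiling, e.g.:
#
#     torch.set_float32_matmul_precision("high")
#     # or equivalently: torch.backends.cuda.matmul.allow_tf32 = True
#
# after which mm/addmm/bmm/baddbmm lowerings may use TF32 tensor cores.

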
375 """ 376 fake_mode = detect_fake_mode(example_inputs) 377 if not fake_mode: 378 fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) 379 FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) 380 else: 381 ctx = ( 382 contextlib.nullcontext() 383 if not force_allow_non_fake_inputs 384 else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) 385 ) 386 with ctx: # type: ignore[attr-defined] 387 FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( 388 *example_inputs 389 ) 390 391 return fake_mode 392 393 394def should_use_remote_fx_graph_cache(): 395 if config.fx_graph_remote_cache: 396 return True 397 if not config.is_fbcode(): 398 return False 399 if torch.version.hip is not None: 400 return False 401 402 try: 403 from triton.fb.fb_memcache import MEMCACHE_VERSION 404 except ModuleNotFoundError: 405 return False 406 407 return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int( 408 "pytorch/remote_cache:fx_graph_memcache_version" 409 ) 410 411 412# pass config dict back to user 413def get_patched_config_dict(config_patches=None) -> Dict[str, Any]: 414 with config.patch(config_patches): 415 return config.get_config_copy() 416 417 418@functools.wraps 419def with_fresh_cache_if_config(f): 420 if config.force_disable_caches: 421 with fresh_inductor_cache(): 422 return f 423 else: 424 return f 425 426 427@DebugContext.wrap 428@torch.utils._python_dispatch._disable_current_modes() 429@time_and_log(attr="compilation time (in seconds)") 430# Need this decorator for compile_fx_inner even if we already have one for 431# compile_fx. The reason is the compilation for backward graph may happen after 432# compile_fx return and we may want to use the _LazyGraphModule for compiling 433# the backward graph as well. 434@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module) 435@with_fresh_cache_if_config 436@dynamo_utils.dynamo_timed(phase_name="inductor_compile", fwd_only=False) 437def compile_fx_inner( 438 gm: torch.fx.GraphModule, 439 example_inputs: List[torch.Tensor], 440 cudagraphs: Optional[BoxedBool] = None, 441 static_input_idxs: Optional[List[int]] = None, 442 is_backward: bool = False, 443 graph_id: Optional[int] = None, 444 cpp_wrapper: bool = False, 445 aot_mode: bool = False, 446 is_inference: bool = False, 447 boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, 448 user_visible_outputs: Optional[Dict[str, None]] = None, 449 layout_opt: Optional[bool] = None, 450 extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, 451) -> Union[CompiledFxGraph, str]: 452 """ 453 Inductor API that compiles a single graph. 454 455 If you change the argument list for this function, make sure you 456 also update the call to save_args_for_compile_fx_inner below accordingly. 457 """ 458 if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: 459 # trigger the real recompilation for _LazyGraphModule before returning 460 # the forward method. 
@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(attr="compilation time (in seconds)")
# Need this decorator for compile_fx_inner even if we already have one for
# compile_fx. The reason is that compilation of the backward graph may happen after
# compile_fx returns, and we may want to use the _LazyGraphModule for compiling
# the backward graph as well.
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
@with_fresh_cache_if_config
@dynamo_utils.dynamo_timed(phase_name="inductor_compile", fwd_only=False)
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Inductor API that compiles a single graph.

    If you change the argument list for this function, make sure you
    also update the call to save_args_for_compile_fx_inner below accordingly.
    """
    if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
        # trigger the real recompilation for _LazyGraphModule before returning
        # the forward method.
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(gm)
        return make_boxed_func(gm.forward)

    if static_input_idxs is None:
        static_input_idxs = []

    assert isinstance(
        next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
    ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

    if config.save_args:
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            cudagraphs=cudagraphs,
            static_input_idxs=static_input_idxs,
            is_backward=is_backward,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            is_inference=is_inference,
            boxed_forward_device_index=boxed_forward_device_index,
            user_visible_outputs=user_visible_outputs,
            layout_opt=layout_opt,
        )

    if cudagraphs is None:
        cudagraphs = BoxedBool(config.triton.cudagraphs)

    # Inputs to fx_codegen_and_compile
    # Anything that affects codegen should go here, so if the signature
    # of fx_codegen_and_compile changes, the dict should be updated accordingly
    graph_kwargs = {
        "cudagraphs": cudagraphs,
        "static_input_idxs": static_input_idxs,
        "is_backward": is_backward,
        "graph_id": graph_id,
        "cpp_wrapper": cpp_wrapper,
        "aot_mode": aot_mode,
        "is_inference": is_inference,
        "user_visible_outputs": user_visible_outputs,
        "layout_opt": layout_opt,
        "extern_node_serializer": extern_node_serializer,
    }

    start = time.time()

    fx_graph_remote_cache = should_use_remote_fx_graph_cache()
    inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
    if (
        not config.force_disable_caches
        and (config.fx_graph_cache or fx_graph_remote_cache)
        and not aot_mode
    ):
        for i, input in enumerate(example_inputs):
            if (
                isinstance(input, torch.Tensor)
                and input.device.type == "cuda"
                and i in static_input_idxs
            ):
                input._is_inductor_static = True  # type: ignore[attr-defined]

        compiled_graph = FxGraphCache.load(
            fx_codegen_and_compile,
            gm,
            example_inputs,
            graph_kwargs,
            inputs_to_check,
            local=config.fx_graph_cache,
            remote=fx_graph_remote_cache,
        )
    else:
        compiled_graph = fx_codegen_and_compile(
            gm, example_inputs, **graph_kwargs  # type: ignore[arg-type]
        )

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)

    # check cudagraph disabling reasons from inductor lowering
    if cudagraphs and compiled_graph.disabled_cudagraphs_reason:
        if "cuda" in compiled_graph.device_types:
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
            )
        else:
            counters["inductor"]["cudagraph_skips"] += 1
        BoxedBool.disable(cudagraphs)

    # Return the output strides to the caller via TracingContext
    context = torch._guards.TracingContext.try_get()
    if context is not None and context.output_strides is not None:
        assert len(context.output_strides) == 0
        context.output_strides.extend(compiled_graph.output_strides)

    if aot_mode:
        return compiled_graph

    if cudagraphs:
        # output args are tuple of first argument
        output = output_node(gm)
        assert len(output.args) == 1
        stack_traces = [
            (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
            for arg in output.args[0]
        ]

        complex_memory_overlap_inputs = any(
            complex_memory_overlap(t)
            for t in example_inputs
            if isinstance(t, torch.Tensor)
        )

        if not config.triton.cudagraph_support_input_mutation:
            # Skip supports for cudagraph-managed tensors
            from torch._inductor.cudagraph_utils import (
                check_for_mutation_ignore_cuda_graph_managed_tensor,
            )

            has_mutation_str = check_for_mutation_ignore_cuda_graph_managed_tensor(
                gm, compiled_graph, static_input_idxs
            )
            has_mutation = has_mutation_str is not None

            if has_mutation:
                compiled_graph.disabled_cudagraphs_reason = has_mutation_str
        else:
            # Check mutation later to support cudagraph-managed tensors
            has_mutation = None

        cudagraph_tests = [
            (not has_mutation, "mutated inputs"),
            (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
            (not complex_memory_overlap_inputs, "complex memory overlap"),
            (
                all(
                    isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs
                ),
                "non-Tensor inputs",
            ),
        ]
        cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]

        if not cudagraph_fail_reasons:
            if not config.triton.cudagraph_trees:
                # Force specialize all inputs so that CUDA graphs will work
                for t in example_inputs:
                    if isinstance(t, torch.SymInt):
                        int(t)  # guard

            if (
                boxed_forward_device_index is not None
                and not is_inference
                and not is_backward
            ):
                boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

            compiled_graph.current_callable = cudagraphify(
                compiled_graph.current_callable,
                example_inputs,
                static_input_idxs=static_input_idxs,
                device_index=next(iter(compiled_graph.device_idxs)),
                stack_traces=stack_traces,
                is_backward=is_backward,
                is_inference=is_inference,
                constants=tuple(compiled_graph.constants.values()),
                placeholders=tuple(get_placeholders(gm.graph)),
                mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
            )
        else:
            BoxedBool.disable(cudagraphs)

            # See [Backward Generation Handling]
            # if we cudagraph'd the forward and set the device, we need to let the
            # cudagraph manager know we are running the backward even if we will not
            # run it in cudagraphs
            if is_backward and config.triton.cudagraph_trees:
                assert boxed_forward_device_index is not None
                assert boxed_forward_device_index.value is not None
                compiled_graph_callable = compiled_graph.current_callable

                manager = torch._inductor.cudagraph_trees.get_manager(
                    boxed_forward_device_index.value, create_if_none_exists=False
                )
                # should already exist from forward
                assert manager is not None

                def compiled_artifact(new_inputs):
                    manager.set_to_running_backward()  # type: ignore[union-attr]
                    return compiled_graph_callable(new_inputs)

                compiled_graph.current_callable = compiled_artifact

            if "cuda" in compiled_graph.device_types:
                # prefer better disable_cudagraphs_reason bc stack trace
                # TODO: migrate all disable reasons to stack trace, refactor
                if compiled_graph.disabled_cudagraphs_reason:
                    log_cudagraph_skip_and_bump_counter(
                        compiled_graph.disabled_cudagraphs_reason
                    )
                else:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                    )

    # cudagraphs does its own aligning of inputs
    if not cudagraphs:
        new_callable = align_inputs_from_check_idxs(
            compiled_graph.current_callable, inputs_to_check
        )
        if new_callable is not compiled_graph.current_callable:
            compiled_graph.current_callable = new_callable

    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )

    # aot autograd needs to know to pass in inputs as a list
    compiled_graph._boxed_call = True
    return compiled_graph


@dynamo_utils.preserve_rng_state()
def fx_codegen_and_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    # Use a dict with None value rather than a set for deterministic
    # iteration order just in case.
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    if is_tf32_warning_applicable(gm):
        _warn_tf32_disabled()

    # lift the maximum depth of the Python interpreter stack
    # to accommodate large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    _step_logger()(
        logging.INFO,
        "torchinductor compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    V.debug.fx_graph(gm, example_inputs)
    # TODO: Should we actually dump this? It should be redundant with the aot
    # structured logs...
    # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))

    shape_env = _shape_env_from_inputs(example_inputs)

    # Convert view to reshape in the graph. This is necessary primarily for
    # layout optimization. Do it unconditionally for uniformity.
    #
    # It's needed because when we do layout optimization, a contiguous tensor
    # in eager mode may become a channels-last tensor. A view op that could
    # previously be applied to the contiguous tensor may no longer be applicable
    # to the channels-last tensor. An error like
    #   RuntimeError: view size is not compatible with input tensor's size and stride
    #   (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    # will be printed.
    #
    # Replace view ops with reshape ops in this case.
    # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
    #
    # Also this has to be done before FakeTensorProp below to avoid the failed
    # .view() call.
    view_to_reshape(gm)

    # It is safe to run FakeTensorProp under no_grad because by the time
    # we're in inductor, we assume that AOTAutograd has already "taken care"
    # of autograd, so there should be no more autograd-related APIs in the
    # graph.
    with torch.no_grad():
        fake_mode = fake_tensor_prop(gm, example_inputs)

    # pattern matcher passes might not preserve striding information
    # on node.meta["val"]. if in the future we rely on these being
    # correct we will need to fix.

    with V.set_fake_mode(fake_mode):
        # has some issues with memory in training
        _recursive_post_grad_passes(gm, is_inference=is_inference)
        V.debug.fx_graph_transformed(gm, example_inputs)
        post_grad_graphs_log.debug(
            "%s",
            lazy_format_graph_code(
                "AFTER POST GRAD", gm, include_stride=True, include_device=True
            ),
        )
        trace_structured(
            "inductor_post_grad_graph",
            payload_fn=lambda: gm.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )
        if config.is_fbcode():
            log_optimus_to_scuba(
                extra_logging={"pt2_configs": str(get_patched_config_dict())}
            )

    with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
        example_inputs
    ):
        const_output_index = None
        const_graph = None
        const_code = None

        if aot_mode and config.aot_inductor.use_runtime_constant_folding:
            const_gm, const_output_index = split_const_gm(gm)

            const_graph = GraphLowering(
                const_gm,
                example_inputs=[],
                shape_env=shape_env,
                graph_id=graph_id,
                cpp_wrapper=cpp_wrapper,
                aot_mode=aot_mode,
                user_visible_outputs=user_visible_outputs,
                extern_node_serializer=extern_node_serializer,
                is_inference=is_inference,
                is_const_graph=True,
            )
            with V.set_graph_handler(const_graph):
                assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                const_graph.run()

                const_code, _ = const_graph.codegen_with_cpp_wrapper()

        graph = GraphLowering(
            gm,
            # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
            # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
            # we currently use fake tensors and defake them later.
            example_inputs=example_inputs,
            shape_env=shape_env,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            user_visible_outputs=user_visible_outputs,
            extern_node_serializer=extern_node_serializer,
            is_inference=is_inference,
            const_output_index=const_output_index,
            const_code=const_code,
            const_module=const_graph,
        )
        metrics_helper = metrics.CachedMetricsHelper()
        with V.set_graph_handler(graph):
            graph.run(*example_inputs)
            output_strides: List[Optional[Tuple[int, ...]]] = []
            if graph.graph_outputs is not None:
                # We'll put the output strides in the compiled graph so we
                # can later return them to the caller via TracingContext
                for out in graph.graph_outputs:
                    if (
                        hasattr(out, "layout")
                        and len(free_unbacked_symbols(out.layout.stride)) == 0
                    ):
                        output_strides.append(
                            tuple(
                                V.graph.sizevars.size_hint(s) for s in out.layout.stride
                            )
                        )
                    else:
                        output_strides.append(None)

            _check_triton_bf16_support(graph)
            compiled_fn = graph.compile_to_fn()
            num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
            metrics.num_bytes_accessed += num_bytes
            metrics.node_runtimes += node_runtimes
            metrics.nodes_num_elem += nodes_num_elem

            if (
                cudagraphs
                and config.triton.cudagraph_skip_dynamic_graphs
                and not V.graph.disable_cudagraphs_reason
                and torch._inductor.utils.any_is_symbolic(*example_inputs)
            ):
                stack_trace = None
                for node in gm.graph.nodes:
                    meta_val = node.meta.get("val", None)
                    if (
                        node.op == "placeholder"
                        or not isinstance(meta_val, torch.Tensor)
                        or not torch._inductor.utils.any_is_symbolic(meta_val)
                    ):
                        continue

                    if stack_trace := node.meta.get("stack_trace", None):
                        break
                disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True."
                if stack_trace:
                    disable = f"{disable} Found from {stack_trace}\n"
                else:
                    disable = f"{disable}\n"
                V.graph.disable_cudagraphs_reason = disable

            if V.aot_compilation is True:
                return compiled_fn

            if cudagraphs and not V.graph.disable_cudagraphs_reason:
                from torch._inductor.cudagraph_utils import (
                    check_lowering_disable_cudagraph,
                )

                V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph(
                    V.graph.device_node_mapping
                )

            compiled_graph = CompiledFxGraph(
                compiled_fn,
                graph,
                output_strides,
                V.graph.disable_cudagraphs_reason,
                metrics_helper.get_deltas(),
            )

    return compiled_graph


def clone_preserve_strides(x: torch.Tensor):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())


def copy_misaligned_inputs(
    new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
) -> None:
    for i in check_inputs_idxs:
        if new_inputs[i].data_ptr() % ALIGNMENT:
            new_inputs[i] = clone_preserve_strides(new_inputs[i])
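

# Illustrative sketch of what clone_preserve_strides() computes: the clone only needs
# enough storage to reach the last addressable element, not numel() elements. E.g.
# (hypothetical values):
#
#     x = torch.empty(4, 4)[:, ::2]              # size (4, 2), stride (4, 2)
#     needed = (4 - 1) * 4 + (2 - 1) * 2 + 1     # = 15 elements of backing storage
#     y = clone_preserve_strides(x)
#     assert y.size() == x.size() and y.stride() == x.stride()
#
# so the copy keeps the exact (possibly non-contiguous) layout of the original.

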
916 """ 917 ids_to_check = [] 918 919 for i, input in enumerate(inputs): 920 if not isinstance(input, torch.Tensor): 921 # non-tensors don't need alignment 922 continue 923 if input.device.type != "cuda": 924 # right now we only care for cuda tensors 925 continue 926 with maybe_get_suppress_shape_guards_ctx(): 927 # suppress guards so that tensor_is_aligned and should_assume_input_aligned 928 # do not add guards on input's storage offset 929 if i in static_input_idxs and tensor_is_aligned(input): 930 continue 931 if not should_assume_input_aligned(input): 932 continue 933 934 # if we get here, then 935 # (a) our triton code assumes that the input is aligned 936 # (b) we can't be sure ahead of time that the input will actually be aligned. 937 # therefore, at runtime, we'll need to check that the input is aligned 938 # (and if not, clone it to make it aligned.) 939 ids_to_check.append(i) 940 941 return ids_to_check 942 943 944def align_inputs_from_check_idxs( 945 model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int] 946): 947 if len(inputs_to_check) == 0: 948 return model 949 950 def run(new_inputs): 951 copy_misaligned_inputs(new_inputs, inputs_to_check) 952 return model(new_inputs) 953 954 return run 955 956 957@dynamo_utils.dynamo_timed 958def cudagraphify( 959 model: torch.fx.GraphModule, 960 inputs: List[torch.Tensor], 961 static_input_idxs: Sequence[int] = (), 962 *, 963 device_index: int, 964 stack_traces: List[Optional[str]], 965 is_backward: bool, 966 is_inference: bool, 967 constants: Tuple[torch.Tensor, ...] = (), 968 placeholders: Tuple[torch.fx.Node, ...] = (), 969 mutated_input_idxs: Tuple[int, ...] = (), 970): 971 from torch._inductor.cudagraph_trees import ( 972 cudagraphify_impl as new_cudagraphify_impl, 973 ) 974 975 cudagraphify_fn: Callable[..., Any] 976 if config.triton.cudagraph_trees: 977 cudagraphify_fn = functools.partial( 978 new_cudagraphify_impl, 979 device_index=device_index, 980 stack_traces=stack_traces, 981 is_backward=is_backward, 982 is_inference=is_inference, 983 constants=constants, 984 placeholders=placeholders, 985 mutated_input_idxs=mutated_input_idxs, 986 ) 987 else: 988 cudagraphify_fn = cudagraphify_impl 989 990 # if using fake tensors, defer cudagraphs until we get real inputs at runtime 991 if not any(isinstance(inp, FakeTensor) for inp in inputs): 992 return cudagraphify_fn(model, inputs, static_input_idxs) 993 994 compiled_fn = None 995 996 def run(new_inputs): 997 nonlocal compiled_fn 998 if compiled_fn is None: 999 with dynamo_utils.preserve_rng_state(): 1000 compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) 1001 return compiled_fn(new_inputs) 1002 1003 return run 1004 1005 1006def remove_unaligned_input_idxs( 1007 inputs: Union[List[torch.Tensor], Sequence[int]], 1008 static_input_idxs: Sequence[int], 1009): 1010 """ 1011 We require all inputs to be aligned, so introduce a copy for any 1012 that aren't. 
1013 """ 1014 aligned_static_input_idxs = [] 1015 for idx, input in zip(static_input_idxs, inputs): 1016 if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0: 1017 aligned_static_input_idxs.append(idx) 1018 if len(aligned_static_input_idxs) != len(static_input_idxs): 1019 return aligned_static_input_idxs 1020 return static_input_idxs 1021 1022 1023def static_input(x: torch.Tensor): 1024 """ 1025 Copy and input while preserving strides 1026 """ 1027 # TODO(jansel): figure out why this version doesn't work: 1028 # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device) 1029 needed_size = ( 1030 sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 1031 ) 1032 buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device) 1033 return torch.as_strided(buffer, x.size(), x.stride()) 1034 1035 1036def index_expanded_dims_and_copy_( 1037 dst: torch.Tensor, 1038 src: torch.Tensor, 1039 expanded_dims: List[int], 1040): 1041 "Index into expanded dimensions of both dst and src then copy_" 1042 dst = index_expanded_dims(dst, expanded_dims) 1043 src = index_expanded_dims(src, expanded_dims) 1044 dst.copy_(src) 1045 1046 1047def cudagraphify_impl( 1048 model: torch.fx.GraphModule, 1049 inputs: List[torch.Tensor], 1050 static_input_idxs: Sequence[int] = (), 1051): 1052 """ 1053 Assumes inputs[static_input_idxs[i]] are always the same memory address 1054 """ 1055 check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) 1056 static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) 1057 copy_misaligned_inputs(inputs, check_input_idxs) 1058 1059 assert isinstance(inputs, list) 1060 1061 inps_expanded_dims = [ 1062 get_expanded_dims(x) if idx not in static_input_idxs else [] 1063 for idx, x in enumerate(inputs) 1064 ] 1065 1066 # allocate static tensor inputs 1067 static_inputs = [ 1068 x 1069 if not isinstance(x, torch.Tensor) 1070 else static_input(x) 1071 if idx not in static_input_idxs 1072 else x.detach() 1073 for idx, x in enumerate(inputs) 1074 ] 1075 1076 # copy over input values for fresh allocations 1077 for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)): 1078 if isinstance(x, torch.Tensor) and idx not in static_input_idxs: 1079 index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims) 1080 1081 # warmup 1082 torch.cuda.synchronize() 1083 stream = torch.cuda.Stream() 1084 stream.wait_stream(torch.cuda.current_stream()) 1085 # copy static_inputs because it will be cleared in model 1086 with torch.cuda.stream(stream): 1087 model(list(static_inputs)) 1088 stream.synchronize() 1089 torch.cuda.current_stream().wait_stream(stream) 1090 torch.cuda.synchronize() 1091 1092 # record 1093 graph = torch.cuda.CUDAGraph() 1094 with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"): 1095 static_outputs = model(list(static_inputs)) 1096 if not isinstance(static_outputs, (list, tuple)): 1097 static_outputs = (static_outputs,) 1098 1099 if config.size_asserts: 1100 1101 def run(new_inputs): 1102 assert len(static_inputs) == len(new_inputs) 1103 for idx, (dst, src, expanded_dims) in enumerate( 1104 zip(static_inputs, new_inputs, inps_expanded_dims) 1105 ): 1106 if not isinstance(dst, torch.Tensor): 1107 pass 1108 elif idx in static_input_idxs: 1109 assert dst.data_ptr() == src.data_ptr() 1110 else: 1111 # TODO - could make one single op of multiple slices 1112 # and avoid dispatch. 
1113 # Could also pre-index the `dst` tensors 1114 index_expanded_dims_and_copy_(dst, src, expanded_dims) 1115 new_inputs.clear() 1116 graph.replay() 1117 return static_outputs 1118 1119 else: 1120 copy_indices = [ 1121 idx for idx in range(len(static_inputs)) if idx not in static_input_idxs 1122 ] 1123 1124 def run(new_inputs): 1125 for idx in copy_indices: 1126 expanded_dims = inps_expanded_dims[idx] 1127 index_expanded_dims_and_copy_( 1128 static_inputs[idx], new_inputs[idx], expanded_dims 1129 ) 1130 new_inputs.clear() 1131 graph.replay() 1132 return static_outputs 1133 1134 return align_inputs_from_check_idxs(run, check_input_idxs) 1135 1136 1137def compile_fx_aot( 1138 model_: torch.fx.GraphModule, 1139 example_inputs_: List[torch.Tensor], 1140 inner_compile: Callable[..., Any] = compile_fx_inner, 1141 config_patches: Optional[Dict[str, Any]] = None, 1142): 1143 config_patches: Dict[str, Any] = ( 1144 {"cpp_wrapper": True} 1145 if config_patches is None 1146 else {**config_patches, "cpp_wrapper": True} 1147 ) 1148 if ( 1149 "aot_inductor.output_path" not in config_patches 1150 and not config.aot_inductor.output_path 1151 ): 1152 config_patches = { 1153 **config_patches, 1154 "aot_inductor.output_path": code_hash(model_.code), 1155 } 1156 1157 extern_node_serializer = config_patches.pop("extern_node_serializer", None) 1158 with V.set_aot_compilation(True): 1159 compiled_lib_path = compile_fx( 1160 model_, 1161 example_inputs_, 1162 inner_compile=functools.partial( 1163 inner_compile, 1164 aot_mode=True, 1165 extern_node_serializer=extern_node_serializer, 1166 ), 1167 config_patches=config_patches, 1168 ) 1169 assert os.path.exists( 1170 compiled_lib_path 1171 ), f"AOTInductor compiled library does not exist at {compiled_lib_path}" 1172 return compiled_lib_path 1173 1174 1175_graph_counter = count(0) 1176 1177 1178def fw_compiler_freezing( 1179 aot_autograd_model: torch.fx.GraphModule, 1180 aot_example_inputs: List[torch.Tensor], 1181 dynamo_model: torch.fx.GraphModule, 1182 num_example_inputs: int, 1183 inner_compile: Callable[..., Any], 1184 cudagraphs: BoxedBool, 1185 graph_id: int, 1186 forward_device: BoxedDeviceIndex, 1187): 1188 from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze 1189 1190 # partition_fn won't be called 1191 _recursive_joint_graph_passes(aot_autograd_model) 1192 1193 layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True) 1194 if layout_opt: 1195 # make sure meta['val'] is properly setup 1196 fake_tensor_prop(aot_autograd_model, aot_example_inputs, True) 1197 convert_conv_weights_to_channels_last(aot_autograd_model) 1198 1199 opt_model, preserved_arg_indices = freeze( 1200 dynamo_model, 1201 aot_autograd_model, 1202 aot_example_inputs, # type: ignore[arg-type] 1203 ) 1204 1205 aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices] 1206 num_fixed = len(preserved_arg_indices) - num_example_inputs 1207 1208 fake_mode = detect_fake_mode(aot_example_inputs) 1209 1210 # for freezing, all graph outputs should be user visible 1211 *_, model_outputs_node = opt_model.graph.nodes 1212 model_outputs = model_outputs_node.args[0] 1213 user_visible_outputs = dict.fromkeys( 1214 n.name for n in model_outputs if isinstance(n, torch.fx.Node) 1215 ) 1216 1217 static_input_idxs = list(range(num_fixed)) 1218 # constant params will be real tensors, not fake 1219 tracing_context = torch._guards.TracingContext.try_get() 1220 if tracing_context is not None: 1221 params_flat = tracing_context.params_flat 
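

# For readers unfamiliar with the capture/replay pattern used above, this is the bare
# skeleton it follows (a sketch only; real code must also handle warmup on a side
# stream, alignment, and expanded dims as cudagraphify_impl does):
#
#     static_inp = torch.zeros(8, device="cuda")
#     g = torch.cuda.CUDAGraph()
#     with torch.cuda.graph(g):
#         static_out = static_inp * 2              # recorded, not executed eagerly
#     static_inp.copy_(torch.arange(8., device="cuda"))
#     g.replay()                                   # re-runs the captured kernels in place
#     # static_out now holds the doubled values; only the static buffers' contents change.

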
def compile_fx_aot(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
):
    config_patches: Dict[str, Any] = (
        {"cpp_wrapper": True}
        if config_patches is None
        else {**config_patches, "cpp_wrapper": True}
    )
    if (
        "aot_inductor.output_path" not in config_patches
        and not config.aot_inductor.output_path
    ):
        config_patches = {
            **config_patches,
            "aot_inductor.output_path": code_hash(model_.code),
        }

    extern_node_serializer = config_patches.pop("extern_node_serializer", None)
    with V.set_aot_compilation(True):
        compiled_lib_path = compile_fx(
            model_,
            example_inputs_,
            inner_compile=functools.partial(
                inner_compile,
                aot_mode=True,
                extern_node_serializer=extern_node_serializer,
            ),
            config_patches=config_patches,
        )
        assert os.path.exists(
            compiled_lib_path
        ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
        return compiled_lib_path


_graph_counter = count(0)


def fw_compiler_freezing(
    aot_autograd_model: torch.fx.GraphModule,
    aot_example_inputs: List[torch.Tensor],
    dynamo_model: torch.fx.GraphModule,
    num_example_inputs: int,
    inner_compile: Callable[..., Any],
    cudagraphs: BoxedBool,
    graph_id: int,
    forward_device: BoxedDeviceIndex,
):
    from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

    # partition_fn won't be called
    _recursive_joint_graph_passes(aot_autograd_model)

    layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
    if layout_opt:
        # make sure meta['val'] is properly set up
        fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
        convert_conv_weights_to_channels_last(aot_autograd_model)

    opt_model, preserved_arg_indices = freeze(
        dynamo_model,
        aot_autograd_model,
        aot_example_inputs,  # type: ignore[arg-type]
    )

    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs

    fake_mode = detect_fake_mode(aot_example_inputs)

    # for freezing, all graph outputs should be user visible
    *_, model_outputs_node = opt_model.graph.nodes
    model_outputs = model_outputs_node.args[0]
    user_visible_outputs = dict.fromkeys(
        n.name for n in model_outputs if isinstance(n, torch.fx.Node)
    )

    static_input_idxs = list(range(num_fixed))
    # constant params will be real tensors, not fake
    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context is not None:
        params_flat = tracing_context.params_flat
        assert params_flat is not None
        for i in range(len(params_flat)):
            if i not in preserved_arg_indices:
                params_flat[i] = None

        if tracing_context.fw_metadata:
            static_input_idxs += tracing_context.fw_metadata.static_parameter_indices

    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(
            opt_model,
            aot_example_inputs,
            static_input_idxs=static_input_idxs,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=True,
            boxed_forward_device_index=forward_device,
            layout_opt=layout_opt,
            user_visible_outputs=user_visible_outputs,
        )

    # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
    # that drops constant-ified params
    if V.aot_compilation is True:
        return optimized_function

    def wrapper(args):
        args_new = [args[i] for i in preserved_arg_indices]
        args.clear()
        return optimized_function(args_new)

    wrapper._boxed_call = True  # type: ignore[attr-defined]

    return wrapper


@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
def compile_fx(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
    decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
):
    """Main entry point for compiling a given FX graph"""
    if config_patches:
        with config.patch(config_patches):
            return compile_fx(
                model_,
                example_inputs_,
                # need extra layer of patching as backwards is compiled out of scope
                inner_compile=config.patch(config_patches)(inner_compile),
                decompositions=decompositions,
            )

    if config.cpp_wrapper:
        with config.patch(
            {
                "cpp_wrapper": False,
                "triton.autotune_cublasLt": False,
                "triton.cudagraphs": False,
                "triton.store_cubin": True,
            }
        ), V.set_real_inputs(example_inputs_):
            inputs_ = example_inputs_
            if isinstance(model_, torch.fx.GraphModule):
                fake_inputs = [
                    node.meta.get("val")
                    for node in model_.graph.nodes
                    if node.op == "placeholder"
                ]
                if all(v is not None for v in fake_inputs):
                    # Validate devices before switching to fake tensors.
                    for idx, fi, i in zip(count(), fake_inputs, inputs_):
                        if fi.device != i.device:
                            raise ValueError(
                                f"Device mismatch between fake input and example input at position #{idx}: "
                                f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
                                "make sure torch.export() and torch.aot_compile() run on the same device."
                            )
                    inputs_ = fake_inputs
            return compile_fx(
                model_,
                inputs_,
                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                decompositions=decompositions,
            )

    recursive_compile_fx = functools.partial(
        compile_fx,
        inner_compile=inner_compile,
        decompositions=decompositions,
    )

    if not graph_returns_tuple(model_):
        return make_graph_return_tuple(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    if isinstance(model_, torch.fx.GraphModule):
        if isinstance(model_.graph._codegen, _PyTreeCodeGen):
            # this graph is the result of dynamo.export()
            return handle_dynamo_export_graph(
                model_,
                example_inputs_,
                recursive_compile_fx,
            )

        model_ = _recursive_pre_grad_passes(model_, example_inputs_)

    if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
        return flatten_graph_inputs(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    assert not config._raise_error_for_testing
    num_example_inputs = len(example_inputs_)
    cudagraphs = BoxedBool(config.triton.cudagraphs)
    forward_device = BoxedDeviceIndex(None)

    graph_id = next(_graph_counter)

    decompositions = (
        decompositions if decompositions is not None else select_decomp_table()
    )

    @dynamo_utils.dynamo_timed
    def fw_compiler_base(
        model: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        is_inference: bool,
    ):
        if is_inference:
            # partition_fn won't be called
            _recursive_joint_graph_passes(model)

        fixed = torch._inductor.utils.num_fw_fixed_arguments(
            num_example_inputs, len(example_inputs)
        )

        user_visible_outputs = {}

        if config.keep_output_stride:
            model_outputs_node = output_node(model)
            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
            num_model_outputs = len(model_outputs)

            context = torch._guards.TracingContext.try_get()
            # See Note [User Outputs in the inductor graph]
            if context is not None and context.fw_metadata and not is_inference:
                original_output_start_index = (
                    context.fw_metadata.num_mutated_inp_runtime_indices
                )
            else:
                original_output_start_index = 0

            if isinstance(model_, torch.fx.GraphModule):
                *_, orig_model_outputs_node = model_.graph.nodes
                assert orig_model_outputs_node.op == "output"
                orig_model_outputs, _ = pytree.tree_flatten(
                    orig_model_outputs_node.args
                )
                num_orig_model_outputs = len(orig_model_outputs)
            else:
                num_orig_model_outputs = num_model_outputs

            assert num_orig_model_outputs <= num_model_outputs

            # Note [User Outputs in the inductor graph]
            # We make the following assumption:
            # For inference
            #   len(orig_model_outputs) == len(model_outputs)
            # For training
            #   len(orig_model_outputs) <= len(model_outputs)
            # During training, most of the time model_outputs starts with the
            # original module's outputs followed by saved activations.
            # But this can be untrue if the model has in-place updated tensors;
            # AOTAutograd will make those tensors be returned before the original
            # module's outputs.
            # To make things safe, we use the original_output_start_index field
            # set by AOTAutograd to decide where the original module outputs start.
            orig_output_end_idx = original_output_start_index + num_orig_model_outputs
            # Sanity check: we are about to splice out the "user" outputs from the full set
            # of "graph" outputs. Make sure we're within bounds.
            assert orig_output_end_idx <= num_model_outputs

            user_visible_outputs = dict.fromkeys(
                n.name
                for n in model_outputs[original_output_start_index:orig_output_end_idx]
                if isinstance(n, torch.fx.Node)
            )

        return inner_compile(
            model,
            example_inputs,
            static_input_idxs=get_static_input_idxs(fixed),
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=is_inference,
            boxed_forward_device_index=forward_device,
            user_visible_outputs=user_visible_outputs,
        )

    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)

    if config.freezing and not torch.is_grad_enabled():
        inference_compiler = functools.partial(
            fw_compiler_freezing,
            dynamo_model=model_,
            num_example_inputs=num_example_inputs,
            inner_compile=inner_compile,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            forward_device=forward_device,
        )
    else:
        inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

    def partition_fn(graph, joint_inputs, **kwargs):
        _recursive_joint_graph_passes(graph)
        return min_cut_rematerialization_partition(
            graph, joint_inputs, **kwargs, compiler="inductor"
        )

    @compile_time_strobelight_meta(phase_name="bw_compiler")
    @dynamo_utils.dynamo_timed
    def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        user_visible_outputs = {}

        if config.bw_outputs_user_visible:
            model_outputs_node = output_node(model)
            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
            user_visible_outputs = dict.fromkeys(
                n.name for n in model_outputs if isinstance(n, torch.fx.Node)
            )
        fixed = count_tangents(model)
        return inner_compile(
            model,
            example_inputs,
            static_input_idxs=list(range(fixed)),
            cudagraphs=cudagraphs,
            is_backward=True,
            graph_id=graph_id,
            boxed_forward_device_index=forward_device,
            user_visible_outputs=user_visible_outputs,
        )

    # TODO: can add logging before/after the call to create_aot_dispatcher_function
    # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
    # once torchdynamo is merged into pytorch

    fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode(
        allow_non_fake_inputs=True
    )
    tracing_context = (
        torch._guards.TracingContext.try_get()
        or torch._guards.TracingContext(fake_mode)
    )

    if V.aot_compilation is True:
        with functorch_config.patch(unlift_effect_tokens=True):
            gm, graph_signature = aot_export_module(
                model_,
                example_inputs_,
                trace_joint=False,
                decompositions=decompositions,
            )
        unlifted_gm = _unlift_graph(model_, gm, graph_signature)
        if "dynamo_flat_name_to_original_fqn" in model_.meta:
            unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
                "dynamo_flat_name_to_original_fqn"
            ]

        # Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515)
        # In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into
        # _sfdp_init() to register patterns.
        # When fallback_random is set to True, the sdpa patterns will be traced during runtime.
        # If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which
        # will be the same as the generated FP16 patterns.
        disable_amp = torch._C._is_any_autocast_enabled()
        context = torch._C._DisableAutocast if disable_amp else contextlib.nullcontext
        with V.set_fake_mode(fake_mode), compiled_autograd.disable(), context():
            return inference_compiler(unlifted_gm, example_inputs_)

    with V.set_fake_mode(fake_mode), torch._guards.tracing(
        tracing_context
    ), compiled_autograd.disable(), functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            inference_compiler=inference_compiler,
            decompositions=decompositions,
            partition_fn=partition_fn,
            keep_inference_input_mutations=True,
        )(model_, example_inputs_)


def _shape_env_from_inputs(inputs: List[torch.Tensor]):
    shape_env = None
    fake_mode = detect_fake_mode(inputs)

    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
    # pass in real inputs for now.
    # if len(inputs) > 0:
    #     assert fake_mode is not None, breakpoint()

    if fake_mode is not None:
        return fake_mode.shape_env

    # When there are no tensor inputs, get shape_env from the first SymInt.
    for input in inputs:
        if isinstance(input, torch.SymInt):
            return input.node.shape_env

    # TODO(voz): Should we always have one anyway?
    return None


def graph_returns_tuple(gm: torch.fx.GraphModule):
    """True if a FX graph returns a tuple"""
    if not isinstance(gm, torch.fx.GraphModule):
        return True  # can't check this, assume true
    (rv,) = output_node(gm).args
    if isinstance(rv, (list, tuple)):
        return True
    if (
        isinstance(rv, torch.fx.node.Node)
        and hasattr(rv.target, "_schema")
        and len(rv.target._schema.returns) > 1
        and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
    ):
        # for graphs whose result is one node with multiple outputs
        return True
    return False


def make_graph_return_tuple(
    gm: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    compile_gm: Callable[..., Any],
):
    """
    Mutate gm so it returns a tuple. This is only needed for graphs
    not created by torchdynamo that return non-tuples.
    """
    node = output_node(gm)
    (rv,) = node.args
    rv, spec = pytree.tree_flatten(rv)
    with gm.graph.inserting_before(node):
        gm.graph.output(rv)
    gm.graph.erase_node(node)
    assert graph_returns_tuple(gm)

    compiled_fn = compile_gm(gm, inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args, **kwargs):
        return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)

    return wrapper
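

# Illustrative sketch of the case make_graph_return_tuple() handles: a hand-written
# fx graph (not produced by dynamo) whose output node returns a bare Tensor.
#
#     def f(x):
#         return x.relu()              # single tensor, not a tuple
#
#     gm = torch.fx.symbolic_trace(f)
#     graph_returns_tuple(gm)          # -> False
#
# compile_fx() detects this and routes through make_graph_return_tuple(), which
# rewrites the output to a flattened tuple and unflattens results in a wrapper.

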
1596 """ 1597 codegen = gm.graph._codegen 1598 gm.graph._codegen = torch.fx.graph.CodeGen() 1599 gm.recompile() 1600 1601 compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs)) 1602 1603 @functools.wraps(compiled_fn) 1604 def wrapper(*args): 1605 return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args))) 1606 1607 return wrapper 1608 1609 1610def _check_triton_bf16_support(graph: GraphLowering) -> None: 1611 def warn_and_skip(device) -> None: 1612 from torch._dynamo.exc import SkipFrame 1613 1614 device_props = torch.cuda.get_device_properties(device) 1615 warnings.warn( 1616 f"{device_props.name} does not support bfloat16 compilation natively, skipping" 1617 ) 1618 raise SkipFrame("BF16 is not supported") 1619 1620 for inp in graph.graph_inputs.values(): 1621 device = getattr(inp, "get_device", lambda: torch.device("meta"))() 1622 if device.type != "cuda" or inp.get_dtype() != torch.bfloat16: 1623 continue 1624 # Print warning and skip frame if attempting to compile for bfloat16 1625 # on device without hardware support for dtype 1626 if torch.cuda.is_bf16_supported(including_emulation=False): 1627 return 1628 warn_and_skip(device) 1629 1630 for out in graph.graph_outputs: 1631 device = getattr(out, "get_device", lambda: torch.device("meta"))() 1632 if device.type != "cuda" or out.get_dtype() != torch.bfloat16: 1633 continue 1634 # Print warning and skip frame if attempting to compile for bfloat16 1635 # on device without hardware support for dtype 1636 if torch.cuda.is_bf16_supported(including_emulation=False): 1637 return 1638 warn_and_skip(device) 1639