# mypy: allow-untyped-defs
import functools
import itertools
import logging
import operator
import os
import re
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy

import torch
import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString, trace_structured
from torch._prims_common import make_channels_last_strides_for
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
    free_unbacked_symbols,
    has_free_symbols,
    resolve_unbacked_bindings,
    RuntimeAssert,
    ShapeEnv,
    SymTypes,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._sympy.numbers import int_oo

from . import config, ir
from .codegen.common import (
    DeviceOpOverrides,
    get_device_op_overrides,
    get_scheduling_for_device,
    get_wrapper_codegen_for_device,
    register_backend_for_device,
)
from .codegen.cpp_wrapper_cpu import CppWrapperCpu
from .codegen.cpp_wrapper_cuda import CppWrapperCuda
from .codegen.wrapper import WrapperCodeGen
from .exc import (
    CppWrapperCodeGenError,
    LoweringException,
    MissingOperatorWithDecomp,
    MissingOperatorWithoutDecomp,
)
from .ir import (
    Constant,
    FixedLayout,
    InputBuffer,
    Pointwise,
    Reduction,
    StorageBox,
    TensorBox,
    TorchBindObject,
)
from .lowering import (
    constrain_to_fx_strides,
    FALLBACK_ALLOW_LIST,
    fallback_handler,
    fallback_node_due_to_unsupported_type,
    layout_constraints,
    lowerings,
    make_fallback,
    needs_realized_inputs,
    unsupported_output_tensor,
)
from .sizevars import SizeVarAllocator
from .utils import (
    convert_shape_to_inductor,
    gather_origins,
    get_cloned_parameter_buffer_name,
    get_sympy_Expr_dtype,
    maybe_get_suppress_shape_guards_ctx,
    should_assume_input_aligned,
)
from .virtualized import NullHandler, V

if TYPE_CHECKING:
    from torch._higher_order_ops.effects import _EffectType

log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
aten = torch.ops.aten

_post_grad_graph_counter = itertools.count()

if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args, **kwargs):
        pass


def supported_dtype_of_cpp_wrapper(dtype, cuda):
    supported_dtype = {
        torch.float32,
        torch.float64,
        torch.int64,
        torch.int32,
        torch.int16,
        torch.int8,
        torch.uint8,
        torch.bool,
        torch.bfloat16,
        torch.complex32,
        torch.complex64,
        torch.complex128,
        torch.float16,
    }
    if cuda:
        supported_dtype.add(torch.float8_e4m3fn)
        supported_dtype.add(torch.float8_e5m2)
        supported_dtype.add(torch.float8_e4m3fnuz)
        supported_dtype.add(torch.float8_e5m2fnuz)

    return dtype in supported_dtype


def may_get_constant_buffer_dtype(constant_buffer):
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
    if isinstance(constant_buffer, sympy.core.numbers.Integer):
        return torch.int64

    if isinstance(constant_buffer, sympy.Expr):
        return get_sympy_Expr_dtype(constant_buffer)

    if constant_buffer.is_integer:
        return torch.int64
    elif constant_buffer.is_float:
        return torch.float32
    else:
        return None


def is_magic_method(op):
    magic_ops = {method_to_operator(m) for m in magic_methods}
    return op in magic_ops


def getattr_recursive(obj, target):
    target_atoms = target.split(".")
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


def mark_nodes_dislike_padding(g):
    """
    Nodes like convolution/convolution_backward want their inputs to be dense.
    If we pad their inputs, we end up with extra calls to copy kernels! On the other hand, padding usually helps reduction.

    The pass finds nodes that dislike padding. These are nodes that can be reached
    from a convolution/convolution_backward in the backward direction without
    going through a reduction.
    """
    if not config.comprehensive_padding:
        return
    ops_dislike_padding = {
        aten.convolution,
        aten.convolution_backward,
    }
    # what's a better way to collect the reduction ops?
    ops_like_padding = {
        aten.var_mean,
        aten.sum,
        aten.mean,
        aten.prod,
        aten.any,
        aten.amin,
        aten.amax,
        aten.min,
        aten.max,
        aten.argmin,
        aten.argmax,
        aten.scatter_reduce,
    }

    def _get_overload_packet(node):
        return (
            node.target._overloadpacket
            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
            else None
        )

    for cur in reversed(g.nodes):
        op = _get_overload_packet(cur)
        if not op:
            continue
        if op in ops_dislike_padding:
            cur.meta["dislike_padding"] = True

        if cur.meta.get("dislike_padding", False):
            # propagate
            for prior in cur.all_input_nodes:
                prior_op = _get_overload_packet(prior)
                if not prior_op:
                    continue
                if prior_op not in ops_like_padding:
                    prior.meta["dislike_padding"] = True


class GraphLowering(torch.fx.Interpreter):
    graph_outputs: List[ir.IRNode]

    def symbolic_sizes_strides(self, ex: torch.Tensor):
        """
        Support dynamic shapes and dynamic strides by assigning variables
        to each dimension. We duck-shape tensors, so if two tensors
        have the same size they get assigned the same symbolic variable.
        """
        if self.reuse_shape_env:
            return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
                ex.stride()
            )
        else:
            from torch._dynamo.source import ConstantSource

            # TODO: this should not be needed once #93059 lands
            # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
            # TODO: make a dedicated UnknownSource for this?
            # NB: This is using the legacy default behavior from
            # create_symbolic_sizes_strides_storage_offset but we hope we can
            # just delete this entirely
            source = ConstantSource(
                f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
            )
            (
                size,
                stride,
                _,
            ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
                ex,
                source,
            )

        size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
        stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
        return size, stride

    def static_sizes_strides(self, ex: torch.Tensor):
        """
        Primarily used for weights
        """
        size = [sympy.Integer(i) for i in ex.size()]
        stride = [sympy.Integer(i) for i in ex.stride()]
        return size, stride

    def init_backend_registration(self):
        if get_scheduling_for_device("cpu") is None:
            from .codegen.cpp import CppScheduling

            register_backend_for_device(
                "cpu", CppScheduling, WrapperCodeGen, CppWrapperCpu
            )

        if get_scheduling_for_device("cuda") is None:
            from .codegen.cuda_combined_scheduling import CUDACombinedScheduling

            # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
            register_backend_for_device(
                "cuda", CUDACombinedScheduling, WrapperCodeGen, CppWrapperCuda
            )

        if get_scheduling_for_device("xpu") is None:
            from .codegen.triton import TritonScheduling

            register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Optional[List[torch.Tensor]] = None,
        shape_env=None,
        graph_id=None,
        cpp_wrapper=False,
        aot_mode=False,
        user_visible_outputs=None,
        layout_opt=None,
        extern_node_serializer=None,
        is_inference=False,
        is_const_graph=False,
        const_output_index=None,
        const_code=None,
        const_module=None,
        name=None,
    ):
        super().__init__(gm)
        self.example_inputs = example_inputs
        self.layout_opt = (
            layout_opt
            if layout_opt is not None
            else self.decide_layout_opt(gm, is_inference=is_inference)
        )
        self.num_channels_last_conv = 0
        self.is_inference = is_inference
        self.is_const_graph = is_const_graph
        self.const_code = const_code
        self.const_module = const_module

        self.extra_traceback = False  # we do our own error wrapping
        if shape_env is None:
            shape_env = ShapeEnv()
            self.reuse_shape_env = False
        else:
            self._shape_env = shape_env
            self.reuse_shape_env = True
        self._shape_env = shape_env
        # We are going to start code generating runtime asserts, so make sure
        # you don't start adding new ones in the lowering process
        shape_env.freeze_runtime_asserts()
        # We're going to mutate ras_by_symbol as we finish generating them
        self.ras_by_symbol: Dict[
            sympy.Symbol, List[RuntimeAssert]
        ] = shape_env.deferred_runtime_asserts.copy()
        self.bound_unbacked_symbols: Set[sympy.Symbol] = set()
        self.sizevars = SizeVarAllocator(shape_env)
        self.graph_input_names: List[str] = []
        self.graph_inputs: Dict[str, TensorBox] = {}
        self.graph_inputs_original: Dict[str, InputBuffer] = {}
        self.device_types: Set[str] = (
            const_module.device_types if const_module else set()
        )
        self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
        self.cuda = False
        self.buffers: List[ir.Buffer] = []
        self.const_output_index: Dict[str, int] = (
            const_output_index if const_output_index else {}
        )
        self.folded_constants: Set[str] = (
            set(const_output_index.keys()) if const_output_index else set()
        )
        self.constants: Dict[str, torch.Tensor] = (
            const_module.constants if const_module else {}
        )
        self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
        self.constant_reprs: Dict[str, str] = {}
        self.removed_buffers: Set[str] = set()
        self.removed_inplace_buffers: Set[str] = set()
        self.mutated_buffers: Set[str] = set()
        self.never_reuse_buffers: Set[str] = set()
        self.inplaced_to_remove: Set[str] = set()
        self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
        self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
        # See `ProxyExecutor Design Note` in ir.py for more details
        self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
        self.extern_node_serializer: Optional[
            Callable[[List[ir.ExternKernelNode]], Any]
        ] = extern_node_serializer
        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
        self.lists: Dict[str, List[str]] = {}
        self.mutated_inputs: Set[str] = set()
        self.mutated_input_idxs: List[int] = []
        self.name_to_buffer: Dict[str, ir.Buffer] = {}
        self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
        self.creation_time = time.time()
        self.name = name
        self.cpp_wrapper = cpp_wrapper

        # record multi_kernel choice for cpp_wrapper so the second pass knows
        # which sub-kernel is picked. Copy cpp_wrapper to another variable
        # since cpp_wrapper flag is set to false for the first pass of codegen.
        self.record_multi_kernel_choice = cpp_wrapper
        self.multi_kernel_to_choice: Dict[str, int] = {}

        self.aot_mode = aot_mode
        self.graph_id = graph_id
        self.post_grad_graph_id = next(_post_grad_graph_counter)
        self.scheduler: torch._inductor.scheduler.Scheduler = None  # type: ignore[assignment]
        self.nodes_prefer_channels_last = (
            self.find_nodes_prefer_channels_last() if self.layout_opt else set()
        )
        mark_nodes_dislike_padding(gm.graph)
        self._warned_fallback = {"aten.convolution_backward"}
        self.user_visible_outputs = (
            user_visible_outputs if user_visible_outputs is not None else {}
        )
        self.cache_key: str = ""  # This is the cache key for the compiled artifact
        self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
        self.cache_linemap: List[
            Tuple[int, str]
        ] = (
            []
        )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
        # Used if lowering encounters cases where cudagraphs are not supported
        self.disable_cudagraphs_reason: Optional[str] = None

        # only keeping one node per device for stack trace purposes
        self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
        self.orig_gm: torch.fx.GraphModule = gm.__copy__()
        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
            "dynamo_flat_name_to_original_fqn", {}
        )
        self.allocated_constant_name = (
            const_module.allocated_constant_name if const_module is not None else {}
        )
        self.init_backend_registration()

        self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}

        self.aligned_inputs: Set[str] = set()

    @staticmethod
    def decide_layout_opt(gm, *, is_inference) -> bool:
        """
        Decide if we should enable layout optimization for this graph based on
        heuristics.
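        Roughly (a summary of the heuristics below, not an exhaustive spec):
        skip when the graph has no convolutions; always enable when mkldnn is
        available and every convolution runs on CPU; otherwise decide based on
        the flop-count and channel-size heuristics implemented below.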
        """
        if not config.layout_optimization:
            return False

        if config.force_layout_optimization:
            return True

        conv_nodes = [
            n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
        ]
        nconv = len(conv_nodes)

        if nconv == 0:
            return False

        # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
        if (
            torch.backends.mkldnn.enabled
            and torch.backends.mkldnn.is_available()
            and all(
                n.args[idx].meta["val"].device == torch.device("cpu")
                for n in conv_nodes
                for idx in [0, 1]
            )
        ):
            return True

        # Following models are skipped due to this:
        # jx_nest_base
        # volo_d1_224
        if len(list(gm.graph.nodes)) >= 300 * nconv:
            log.debug("Skipped layout opt because only a few conv")
            return False

        if any(
            has_free_symbols(n.args[idx].meta["val"])
            for n in conv_nodes
            for idx in [0, 1]
        ):
            log.debug(
                "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
            )
            return False

        def is_grouped(n):
            return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1

        def is_in_out_channel(n):
            return (
                n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
                and n.args[1].meta["val"].size(2) > 1
            )

        def is_small_channel(n):
            return (
                n.args[1].meta["val"].size(0) <= 64
                and n.args[1].meta["val"].size(1) <= 64
            )

        # only grouped convolutions benchmarked as slower in conv samples for inference only
        if is_inference:
            from torch.utils.flop_counter import FlopCounterMode

            flop_counts: Dict[str, float] = defaultdict(float)
            for node in conv_nodes:
                success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
                    node
                )

                if success:
                    with FlopCounterMode(display=False) as flop_counter_mode:
                        with V.fake_mode:
                            node.target(*args, **kwargs)

                    counted_flops = flop_counter_mode.get_total_flops()
                    if is_grouped(node):
                        node_type = "grouped"
                    elif is_small_channel(node):
                        node_type = "small"
                    elif is_in_out_channel(node):
                        node_type = "in_out"
                    else:
                        node_type = "default"

                    flop_counts[node_type] += counted_flops
                else:
                    log.debug("Conv inputs meta not found")

            # average benchmarked channels last speedup / slowdown, < 1 is speedup.
            # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
            # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
            GROUPED_MULTIPLIER = 1.358
            DEFAULT_MULTIPLIER = 0.823
            IN_OUT_MULTIPLIER = 0.725
            SMALL_MULTIPLIER = 0.783

            total_flops = sum(flop_counts.values())
            # TODO - get different values per hardware
            weighted_flops = (
                flop_counts["grouped"] * GROUPED_MULTIPLIER
                + flop_counts["small"] * SMALL_MULTIPLIER
                + flop_counts["in_out"] * IN_OUT_MULTIPLIER
                + flop_counts["default"] * DEFAULT_MULTIPLIER
            )
            do_layout_opt = weighted_flops <= total_flops
            if not do_layout_opt:
                log.debug(
                    "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
                    total_flops,
                    weighted_flops,
                )
            return do_layout_opt

        # Channels last layout can dramatically hurt grouped conv perf. E.g.
        # Conv with arguments like
        #   {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
        #    "stride": [2, 2], "padding": [1, 1], "groups": 2}
        # slows down 31x using channels last.

        # But a lot of timm models use depthwise separable convolution which will
        # result in grouped convolution with in-channel size == 1.
        # For those grouped convolutions, channels last still helps a lot.
        # E.g.
        # Conv with arguments
        #   {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
        #    "stride": [2, 2], "padding": [1, 1], "groups": 58}
        # gets a 1.86x speedup with channels last layout.
        #
        # The following heuristics skip using channels-last if the model contains
        # grouped convolution with in-channels > 1.
        if any(map(is_grouped, conv_nodes)):
            log.debug(
                "Skip layout opt because found grouped convolution with >1 in_channels!"
            )
            return False

        # For some models that contain convolution with larger in-channel than out-channel, applying
        # channels last hurts performance.
        # Following models are skipped due to this:
        # - pytorch_unet
        # - phlippe_densenet (slightly worse)
        # - Background_Matting (1.22x -> 0.821x)
        # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
        if any(map(is_in_out_channel, conv_nodes)):
            log.debug(
                "Skip layout opt because some convolutions have smaller out_channel"
            )
            return False

        # Following models are skipped due to this:
        # - functorch_maml_omniglot
        if all(map(is_small_channel, conv_nodes)):
            log.debug("Skip layout opt because all convolution channels are too small")
            return False

        return True

    def qualify_name(self, name: str) -> str:
        """Prepend the given name with the graph name if any."""
        if self.name is not None:
            return f"{self.name}_{name}"
        return name

    def make_subgraph(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        subgraph_name: str,
    ) -> "GraphLowering":
        """
        Make a subgraph of the current graph with all inherited
        parts, except the graph module (`gm`) and `example_inputs`.
        The subgraphs are lowered separately, but are intended to be
        inlined in the parent graph's codegen. Hence the need
        to maintain the same `shape_env` and other properties.
        The subgraph name is qualified by the parent graph's name.
        """
        return GraphLowering(
            gm=gm,
            example_inputs=example_inputs,
            shape_env=self._shape_env,
            cpp_wrapper=self.cpp_wrapper,
            aot_mode=self.aot_mode,
            extern_node_serializer=self.extern_node_serializer,
            is_inference=self.is_inference,
            name=self.qualify_name(subgraph_name),
        )

    def find_nodes_prefer_channels_last(self):
        """
        The rules to decide if a node prefers channels last are simple:
        1. if it's an input/output of a convolution
        2. if one of its users prefers channels last

        We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
        Rule 2 is also important. It makes sure that indirect inputs to convolution also prefer
        channels last.

        Consider the scenario: conv -> batch-norm -> relu -> conv
        Without rule 2, the batch-norm output may use a contiguous layout. That will cause 2 extra copies:
        1. the output of batch-norm should be channels last initially since its input is a conv's output.
           Forcing the batch-norm's output to be contiguous results in the first copy.
        2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
           We need to convert it to channels last layout, which results in the second copy.
        With rule 2, we make sure all the tensors in the chain use channels last layout, so both copies
        can be saved.
        """
        output_set = set()
        for n in reversed(self.module.graph.nodes):
            if n.target == torch.ops.aten.convolution.default:
                output_set.add(n)
                continue

            for user in n.users:
                if user in output_set:
                    output_set.add(n)
                    break

        # Need a second pass to add downstream nodes of those channels-last nodes to the set.
        # This pass is especially needed to avoid mixed-layout kernel inputs in the backward pass.
        #
        # Let's say a conv-batchnorm's output is passed to relu whose output is in turn returned
        # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
        # Then in the backward pass, the contiguous output of relu may be mixed with other channels-last
        # tensors and passed to a kernel.
        #
        # This pass improves yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
        # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x.
        # This also helps the following models:
        # - res2net101_26w_4s
        # - res2net50_14w_8s
        # - sebotnet33ts_256
        for n in self.module.graph.nodes:
            if n in output_set:
                output_set.update(n.users)

        return output_set

    def warn_fallback(self, name):
        if name not in self._warned_fallback:
            self._warned_fallback.add(name)
            perf_hint_log.info("Using FallbackKernel: %s", name)

    def add_device_info(self, device: torch.device):
        self.device_types.add(device.type)
        if device.index is not None:
            self.device_idxs.add(device.index)
        if V.graph.current_node and device not in self.device_node_mapping:
            self.device_node_mapping[device] = V.graph.current_node

    @property
    def fake_mode(self):
        return V.fake_mode

    def get_buffer(self, buffer_name: str):
        if buffer_name in self.name_to_buffer:
            return self.name_to_buffer[buffer_name]
        if buffer_name in self.graph_inputs:
            return self.graph_inputs[buffer_name]
        if buffer_name in self.constants:
            data = V.graph.constants[buffer_name]
            return ir.ConstantBuffer(
                buffer_name,
                ir.FixedLayout(
                    data.device, data.dtype, *V.graph.static_sizes_strides(data)
                ),
            )
        return None

    def get_dtype(self, buffer_name: str):
        if buffer_name in self.constants:
            return self.constants[buffer_name].dtype
        if buffer_name in self.name_to_buffer:
            return self.name_to_buffer[buffer_name].get_dtype()
        if buffer_name in self.graph_inputs:
            return self.graph_inputs[buffer_name].get_dtype()
        m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
        if m:
            return self.get_dtype(m.group(1))
        raise KeyError(f"could not find {buffer_name}")

    def get_numel(self, buffer_name: str):
        from .ir import MultiOutputLayout

        if buffer_name in self.constants:
            return self.constants[buffer_name].numel()
        if buffer_name in self.name_to_buffer:
            buf = self.name_to_buffer[buffer_name]
            if isinstance(getattr(buf, "layout", None), MultiOutputLayout):
                return 1
            return buf.get_numel()
        if buffer_name in self.graph_inputs:
            return self.graph_inputs[buffer_name].get_numel()
        raise KeyError(f"could not find {buffer_name}")

    @dynamo_timed
    def run(self, *args):
        return super().run(*args)

    def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False):
        name = self.qualify_name(f"buf{len(self.buffers)}")
        self.buffers.append(buffer)
        self.name_to_buffer[name] = buffer
        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
        if (
            not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
            and buffer.get_device() is not None
        ):
            self.add_device_info(buffer.get_device())

        if set_name:
            buffer.name = name
        return name

    def register_list(self, buffer_names: List[str]):
        name = self.qualify_name("list_" + "_".join(buffer_names))
        self.lists[name] = buffer_names
        return name

    def register_users_of(self, node_output):
        def register(value):
            if isinstance(value, (list, tuple)):
                for x in value:
                    register(x)
            if isinstance(value, ir.IRNode):
                if (
                    not hasattr(value, "data")
                    or not isinstance(value.data, ir.IRNode)
                    or not (
                        hasattr(value.data, "data")
                        and isinstance(value.data.data, ir.IRNode)
                    )
                ):
                    return

                for read_name in value.get_read_names():
                    self.name_to_users[read_name].append(value)

        register(node_output)

    def mark_buffer_mutated(self, name: str):
        """
        When a buffer is mutated we need to make sure all the reads to
        the old version are realized before the mutation happens.
        """
        assert isinstance(name, str)
        self.mutated_buffers.add(name)

        if name not in self.name_to_users:
            return

        for user in self.name_to_users[name]:
            user.realize()

    def get_original_value_of_constant(self, name: str):
        """
        In AOTI, module buffers may have been mutated during the tracing and compilation.
        Thus we need to read from previously stored original buffers, to make sure the
        generated model.so uses correct initial values.
        """
        assert name in self.allocated_constant_name and name in self.constants, (
            "Can not find the original value for " + name
        )
        orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name])
        return (
            self.module.meta[orig_name]
            if orig_name in self.module.meta
            else self.constants[name]
        )

    def allocate_non_dup_const_name(self, name, data):
        orig_name = name
        if not config.aot_inductor.use_runtime_constant_folding:
            for constant_name, value in self.constants.items():
                if (
                    not data.is_mkldnn
                    and data.size() == value.size()
                    and data.stride() == value.stride()
                    and data.dtype == value.dtype
                    and data.device == value.device
                    and data.untyped_storage().data_ptr()
                    == value.untyped_storage().data_ptr()
                    and data.storage_offset() == value.storage_offset()
                ):
                    return constant_name

        if name is None:
            name = f"constant{len(self.constants)}"
        if name[0].isdigit():
            name = f"constant_{name}"
        name = self.qualify_name(name)
        # We may generate a var name for each constant in the codegen.
        # Let's only keep sane characters.
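        # Sanitize the name into a valid identifier and de-duplicate it against
        # existing constants; e.g. an (illustrative) name like "fc.weight"
        # becomes "fc_weight", with a numeric suffix appended on collision.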
        prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name)
        name = prefix
        cnt = 0
        while name in self.constants:
            name = f"{prefix}_{cnt}"
            cnt += 1
        self.constants[name] = data
        self.constant_reprs[name] = (
            f"{data.device!r} {data.dtype!r} "
            f"{tuple(data.size())!r} {tuple(data.stride())!r} "
            f"{hash(data):x}"
        )
        self.allocated_constant_name[name] = orig_name
        return name

    def add_tensor_constant(self, data, name=None):
        new_name = self.allocate_non_dup_const_name(name, data)
        return TensorBox.create(
            ir.ConstantBuffer(
                new_name,
                FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
            )
        )

    def constant_name(self, name: str, device_override: Optional[torch.device]):
        """
        We AOT copy constants to the devices they are needed on.
        If device_override doesn't match the constant's device, then
        copy it and return a different name.
        """
        if self.constants[name].device == device_override or device_override is None:
            return name
        with torch.utils._python_dispatch._disable_current_modes():
            # caller might have set fake tensor mode which will create a fake tensor
            # when calling .to, so unset modes here
            return self.allocate_non_dup_const_name(
                f"{name}_{device_override.type}{device_override.index or 0}",
                self.constants[name].to(device_override),
            )

    def placeholder(self, target: str, args, kwargs):
        example = super().placeholder(target, args, kwargs)
        self.graph_input_names.append(target)
        if isinstance(example, SymTypes):
            expr = example.node.expr
            self.graph_inputs[target] = expr
            return expr
        elif isinstance(example, (int, bool, float)):
            expr = sympy.sympify(example)
            self.graph_inputs[target] = expr
            return expr
        if isinstance(example, BackwardState):
            # Ignored arg, must be unused
            # Alternately we could filter this out in AotAutograd
            return None
        assert isinstance(example, torch.Tensor), example
        # todo(chilli): We can remove the last check once we turn buffers into
        # static shape tensors. That's a hack to work around Inductor believing
        # the buffer should be static but us passing in a fake tensor with
        # symbolic shapes.
        if not example._has_symbolic_sizes_strides:
            # the first N inputs are weights
            sizes, strides = self.static_sizes_strides(example)
        else:
            sizes, strides = self.symbolic_sizes_strides(example)
        # TODO(jansel): handle input aliasing
        target = self.qualify_name(target)
        tensor = TensorBox.create(
            InputBuffer(
                target,
                FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )
        self.graph_inputs[target] = tensor
        self.graph_inputs_original[target] = tensor.data.data
        self.add_device_info(example.device)

        # Note: [Input Alignment handling in Inductor]
        # Alignment matters for generating efficient code. Some operations,
        # e.g. vectorized loads, can only be performed on aligned inputs.
        #
        # But if we codegen assuming aligned inputs and then get unaligned
        # inputs at runtime, then we are forced to clone - which is bad for
        # both perf and memory usage.
        #
        # One option would be to guard on storage_offset%ALIGNMENT, and then
        # codegen based on this. But storage_offset guards turned out to be
        # expensive and cause recompiles; instead, we're generating code
        # based on the alignment of the example input without guarding.
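        # Suppress shape guards while inspecting the example input's alignment,
        # so this check itself does not record a guard.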
        with maybe_get_suppress_shape_guards_ctx():
            if should_assume_input_aligned(example):
                self.aligned_inputs.add(target)
        return tensor

    def call_function(self, target, args, kwargs):
        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
            return super().call_function(target, args, kwargs)

        if hasattr(target, "_inductor_lowering_function"):
            # passthrough lowerings from .pattern_matcher
            return target(*args, **kwargs)

        def get_custom_op_layout_constraints(target, args, kwargs):
            # Custom operations that require preserving stride order
            # which run through implicit fallback must constrain their
            # arguments' fx strides
            layout_constraint = None
            if torch._C.Tag.needs_fixed_stride_order in target.tags:
                # We have to set the current args because call_function will immediately
                # evaluate this lowering after creating the fallback, without evaluating
                # the layout constraint
                constrain_fn = functools.partial(
                    constrain_to_fx_strides, ignore_mutated_args_FIXME=True
                )
                args, kwargs = constrain_fn(self.current_node, *args, **kwargs)
                # Also register the layout constraint so when the fallback
                # is used again, we can constrain the args to the same layout
                layout_constraint = constrain_fn
            return layout_constraint, args, kwargs

        if target not in lowerings:
            assert isinstance(
                target, torch._ops.OpOverload
            ), f"{target} is not an OpOverload"
            base_name = target.name().split(".")[0]
            if base_name in FALLBACK_ALLOW_LIST:
                make_fallback(target)
            elif config.implicit_fallbacks:
                layout_constraint, args, kwargs = get_custom_op_layout_constraints(
                    target, args, kwargs
                )
                error = (
                    MissingOperatorWithDecomp
                    if get_decompositions([target])
                    else MissingOperatorWithoutDecomp
                )
                log.info(
                    "Creating implicit fallback for:\n%s",
                    error.operator_str(target, args, kwargs),
                )
                make_fallback(target, layout_constraint)

            elif get_decompositions([target]):
                # There isn't a good way to dynamically patch this in
                # since AOT Autograd already ran. The error message tells
                # the user how to fix it.
                raise MissingOperatorWithDecomp(target, args, kwargs)
            else:
                raise MissingOperatorWithoutDecomp(target, args, kwargs)

        try:
            log.debug("  via %s", lowerings[target])
            out = lowerings[target](*args, **kwargs)
            return out
        except Exception as e:
            raise LoweringException(e, target, args, kwargs).with_traceback(
                e.__traceback__
            ) from None

    @staticmethod
    def can_inline_constant(t: torch.Tensor) -> bool:
        """
        True if this is a small constant attr that will be inlined.
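        Currently this means a 1-D tensor with at most 8 elements.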
        """
        return len(t.shape) == 1 and t.shape[0] <= 8

    def get_attr(self, target, args, kwargs):
        # this is a constant
        value = getattr_recursive(self.module, target)

        if isinstance(value, torch.fx.GraphModule):
            return ir.Subgraph(name=target, graph_module=value)

        if isinstance(value, torch._C.ScriptObject):
            self.torchbind_constants[target] = value
            self.constant_reprs[target] = ""
            return TorchBindObject(target, value)

        if (
            config.aot_inductor.use_runtime_constant_folding
            or config.always_keep_tensor_constants
            or unsupported_output_tensor(value)
        ):
            return self.add_tensor_constant(value, target)

        with no_dispatch():
            if value.shape == ():
                return Constant(value.item(), value.dtype, value.device)
            if self.can_inline_constant(value):
                # tensor lowering has constant inlining logic
                from .lowering import tensor

                return tensor(value.tolist(), dtype=value.dtype, device=value.device)

        return self.add_tensor_constant(value, target)

    def call_module(self, target, args, kwargs):
        raise AssertionError

    def call_method(self, target, args, kwargs):
        raise AssertionError

    def output(self, target, args, kwargs):
        result = super().output(target, args, kwargs)
        if not isinstance(result, (tuple, list)):
            # nested subgraphs can have singleton outputs
            result = (result,)
        assert isinstance(result, (tuple, list)), type(result)
        assert all(
            isinstance(
                x,
                (
                    TensorBox,
                    ir.Constant,
                    type(None),
                    ir.ConstantBuffer,
                    sympy.Expr,
                    sympy.logic.boolalg.Boolean,
                    int,
                    ir.EffectfulKernel,
                ),
            )
            for x in result
        ), result

        fx_node_args = V.graph.current_node.args[0]  # type: ignore[arg-type]
        if not isinstance(fx_node_args, (tuple, list)):
            # nested subgraphs can have singleton outputs
            fx_node_args = (fx_node_args,)
        result = [ir.ExternKernel.realize_input(x) for x in result]
        result_correct_strides = []

        assert len(fx_node_args) == len(result)
        for r, fx_node in zip(result, fx_node_args):
            if not isinstance(r, (ir.TensorBox, ir.BaseView)):
                result_correct_strides.append(r)
            else:
                # AOT Autograd tries to detect stride divergence of inductor from output metadata.
                # Here, we try to avoid spurious divergence by matching insignificant strides,
                # such as those of size-0/1 dimensions.
                result_correct_strides.append(
                    self.try_match_insignificant_strides(
                        r, fx_node.meta["val"].stride()
                    )
                )

        self.graph_outputs = result_correct_strides
        value: ir.IRNode
        for name, value in self.graph_inputs.items():
            assert isinstance(
                value, (TensorBox, sympy.Expr)
            ), f"Unsupported inductor graph input type: {type(value)}"
            if not isinstance(value, TensorBox):
                continue
            value.realize()
            assert isinstance(value, TensorBox)
            value = value.data
            assert isinstance(value, ir.StorageBox)
            value_storage_box = value
            value = value.data
            if not isinstance(value, InputBuffer) or value.get_name() != name:
                # one of our inputs was mutated, need to turn that into a copy
                ir.MutationLayoutSHOULDREMOVE.realize_into(
                    value, self.graph_inputs_original[name]
                )
                # replace output with mutated input
                try:
                    ind = self.graph_outputs.index(value_storage_box)
                    self.graph_outputs[ind] = self.graph_inputs_original[name]
                except ValueError:
                    pass

        self.finalize()
        log.debug(
            "Force channels last inputs for %d conv for the current graph with id %d",
            self.num_channels_last_conv,
            self.graph_id if self.graph_id is not None else -1,
        )

    def finalize(self):
        for buf in self.buffers:
            buf.decide_layout()

    @contextmanager
    def set_current_node(self, node: torch.fx.Node):
        old = self.current_node
        try:
            self.current_node = node
            yield
        finally:
            self.current_node = old

    def try_match_insignificant_strides(
        self,
        tensor,
        meta_strides_inp: Tuple[Union[int, torch.SymInt], ...],
    ) -> ir.TensorBox:
        """
        Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
        dimensions - size 0 or 1 - will be updated.

        If there are real stride differences (NHWC vs NCHW) then the input will be returned.
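        For example, a tensor of size (1, 8) with strides (8, 1) can be
        reinterpreted with strides (1, 1) to match the meta strides, since the
        stride of a size-1 dimension never affects addressing.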
        """

        # should have already been realized
        assert torch._inductor.ir.is_storage_and_layout(tensor)

        meta_strides = [
            s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_strides_inp
        ]

        if all(
            self.sizevars.statically_known_equals(s1, s2)
            for s1, s2 in zip(meta_strides, tensor.get_stride())
        ):
            return tensor

        def significant_strides_equal(shape, meta_strides, tensor_strides):
            for dim, s1, s2 in zip(shape, meta_strides, tensor_strides):
                if self.sizevars.statically_known_leq(dim, 1):  # type: ignore[arg-type]
                    continue

                if not self.sizevars.statically_known_equals(s1, s2):
                    return False

            return True

        if not significant_strides_equal(
            tensor.get_size(), meta_strides, tensor.get_stride()
        ):
            return tensor

        storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor)
        new_stride = list(old_layout.stride)
        for i, s in enumerate(tensor.get_size()):
            if self.sizevars.statically_known_leq(s, 1):  # type: ignore[arg-type]
                new_stride[i] = meta_strides[i]

        new_layout = torch._inductor.ir.FixedLayout(
            old_layout.device,
            old_layout.dtype,
            old_layout.size,
            new_stride,
            old_layout.offset,
        )
        return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout))

    def run_node(self, n: torch.fx.Node):
        def debug(msg):
            log.debug("lowering %s %s", LazyString(n.format_node), msg)

        buffer_watermark = len(self.buffers)

        origins = {n}
        if n.op == "call_function":
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            origins |= gather_origins(args, kwargs)
        with ir.IRNode.current_origins(origins), self.set_current_node(
            n
        ), V.set_current_node(n):
            if (
                n.op == "call_function"
                and n.target is not operator.getitem
                and fallback_node_due_to_unsupported_type(n)
            ):
                debug("fallback_handler")
                result = fallback_handler(n.target, add_to_fallback_set=False)(
                    *args, **kwargs  # type: ignore[possibly-undefined]
                )
            elif n.op == "call_function" and n.target in layout_constraints:
                debug("layout_constraints")
                args, kwargs = layout_constraints[n.target](n, *args, **kwargs)  # type: ignore[index]
                result = self.call_function(n.target, args, kwargs)
            elif is_magic_method(n.target):
                # TODO: this is sus, it probably should be handled in the
                # lowerings themselves similarly to sym_size/sym-stride
                # https://github.com/pytorch/pytorch/issues/127789
                debug("is_magic_method")
                if isinstance(
                    n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
                ):
                    result = n.meta["val"].node.expr
                else:
                    result = super().run_node(n)
            else:
                debug("")
                result = super().run_node(n)

            # Require the same stride order for dense outputs:
            # 1. so user-land view() will not throw because inductor would
            #    otherwise output different strides than eager; long term the
            #    solution is to make view() always succeed with infallible strides.
            # 2. for as_strided ops, we need to make sure the input has the same
            #    size/stride as in the eager model to align with eager behavior.
            as_strided_ops = [
                torch.ops.aten.as_strided.default,
                torch.ops.aten.as_strided_.default,
                torch.ops.aten.as_strided_scatter.default,
                torch.ops.aten.resize.default,
                torch.ops.aten.resize_as.default,
            ]
            is_output = any(user.op == "output" for user in n.users)
            is_input_for_as_strided = any(
                user.target in as_strided_ops for user in n.users
            )

            if n.meta.get("inductor_realize_to_strides", False) and isinstance(
                result, TensorBox
            ):
                result.realize()
                strides = n.meta["val"].stride()
                sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
                if (
                    not hasattr(result, "get_stride")
                    or result.get_stride() != strides
                    and not sym_strides
                ):
                    stride_order = ir.get_stride_order(strides)
                    result = ir.ExternKernel.require_stride_order(result, stride_order)
            if (
                is_output
                and isinstance(result, TensorBox)
                and isinstance(result.data, ir.BaseView)
            ):
                # Realize so that outputs are correctly aliased
                result.realize()

            if (is_output or is_input_for_as_strided) and isinstance(
                n.meta["val"], torch.Tensor
            ):
                strides = n.meta["val"].stride()
                dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
                unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0
                # requiring a stride order for a non-dense output wouldn't
                # recreate the same strides, and would fail with view, defer for now.
                if not unbacked_symbols_in_strides and dense and len(strides):
                    stride_order = ir.get_stride_order(strides)
                    if (
                        len(result.get_size()) == 4
                        and n in self.nodes_prefer_channels_last
                        and n.name not in self.user_visible_outputs
                        and not is_input_for_as_strided
                    ):
                        stride_order = ir.NHWC_STRIDE_ORDER

                    allow_padding = (
                        n.name not in self.user_visible_outputs
                        and not is_input_for_as_strided
                    )
                    result = ir.ExternKernel.require_stride_order(
                        result, stride_order, allow_padding=allow_padding
                    )

            # Realize if (1) any user needs inputs realized, or (2) there are
            # already too many reads and rematerializing can be bad.
            num_users = len(set(n.users))
            if num_users > 1 and isinstance(result, TensorBox):
                for user in n.users:
                    if user.target in needs_realized_inputs:
                        result.realize_hint()
                        # This inclusion is somewhat controversial (from
                        # discussion between Horace, Natalia, and Elias).
                        # Currently, it's not very clear why this is helpful.
                        # The general idea here is that even though a node may
                        # have FlexibleLayout, we still often *treat* it as if
                        # it was contiguous. This appears to sometimes result in
                        # suboptimal behavior.
                        #
                        # When we do a better job selecting layout, we should
                        # revisit this.
                        need_fixed_layout = [
                            torch.ops.aten.convolution_backward.default,
                            torch.ops.aten.mm.default,
                            torch.ops.aten._int_mm.default,
                        ]
                        need_fixed_channels_last_layout = []
                        if not self.layout_opt:
                            need_fixed_layout.append(torch.ops.aten.convolution.default)
                        if torch._C._has_mkldnn:
                            need_fixed_layout += [
                                torch.ops.mkldnn._linear_pointwise.default,
                                torch.ops.mkldnn._linear_pointwise.binary,
                                torch.ops.aten.mkldnn_rnn_layer.default,
                                torch.ops.onednn.qlinear_pointwise.default,
                                torch.ops.onednn.qlinear_pointwise.tensor,
                                torch.ops.onednn.qlinear_pointwise.binary,
                                torch.ops.onednn.qlinear_pointwise.binary_tensor,
                            ]
                            need_fixed_channels_last_layout += [
                                torch.ops.mkldnn._convolution_pointwise.default,
                                torch.ops.mkldnn._convolution_pointwise.binary,
                                torch.ops.mkldnn._convolution_pointwise_.binary,
                                torch.ops.mkldnn._convolution_transpose_pointwise.default,
                                torch.ops.onednn.qconv2d_pointwise.default,
                                torch.ops.onednn.qconv2d_pointwise.binary,
                            ]
                            if torch._C.has_mkl:
                                need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
                        if user.target in need_fixed_layout:
                            result = ir.ExternKernel.require_stride_order(
                                result,
                                ir.get_stride_order(n.meta["val"].stride()),
                                allow_padding=True,
                            )
                        if (
                            user.target in need_fixed_channels_last_layout
                            and n is user.args[0]
                        ):
                            result = ir.ExternKernel.require_stride_order(
                                result,
                                ir.get_stride_order(
                                    make_channels_last_strides_for(n.meta["val"].shape)
                                ),
                            )
                    if user.op == "output":
                        if isinstance(result.data.data, (Pointwise, Reduction)):
                            result.realize()

                # TODO(jansel): introduce a store vs inline choice
                result.mark_reuse(len(n.users))

            # Realize if the IRNode already has accumulated lots of reads
            if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
                # Prevent excessive accumulation in a computed buffer, when
                # there are multiple branches each with small number of memory
                # reads, but they converge to a user.
                result.realize_hint()

            # Realize if a Pointwise has too much stuff to be inlined.
            # As this may cause RecursionError during Inductor's evaluation.
            if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
                curr = result.data.data
                if isinstance(curr, Pointwise):
                    # Use inner fn as a rough proxy. Good enough.
                    if curr.has_large_inner_fn():
                        result.realize()

            # This is not complete, but it doesn't have to be: origin_node
            # tracking is best effort. The logic here critically relies on direct
            # TensorBox -> StorageBox denoting a non-view; we don't bother trying
            # to get views to work. Feel free to add any extra cases as needed.
            #
            # Note: we can't YOLO tree_map over this result, because if there are
            # buffers or a view involved, we might not be able to validly assign
            # the origin_node here.
            if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
                if isinstance(result.data.data, ir.Loops):
                    result.data.data.origin_node = n
                elif isinstance(result.data.data, ir.Buffer):
                    result.data.data.origin_node = n
                    if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
                        result.data.data.data, ir.Loops
                    ):
                        result.data.data.data.origin_node = n
                    # Not really multi-output, can straightforwardly recurse in
                    elif (
                        isinstance(result.data.data, ir.MultiOutput)
                        and not result.data.data.indices
                    ):
                        if isinstance(result.data.data.inputs[0], ir.Buffer):
                            result.data.data.inputs[0].origin_node = n

            self.register_users_of(result)

        new_unbacked_defs = set()
        for i in range(buffer_watermark, len(self.buffers)):
            new_unbacked_defs |= self.buffers[i].get_unbacked_symbol_defs()

        def format_buffers():
            r = []
            for b in self.buffers[buffer_watermark:]:
                r.append(
                    f"unbacked_symbol_defs={b.get_unbacked_symbol_defs()} in:\n{b}\n"
                )
            return "***\n".join(r)

        if n.op != "placeholder":
            # Note [Backwards runtime asserts]
            # Backwards poses an interesting problem for deferred runtime
            # asserts. In the easy case, we may solely close over data
            # dependent sized tensors, and there are no binding sites for
            # unbacked SymInts. In this case, we can just drop all the
            # runtime asserts on the floor: no non-placeholder bindings, no
            # problem.
            #
            # However, it is *possible* for a fresh runtime assert to show up
            # between forwards and backwards. Right now, the freezing process
            # that happens when we lower forwards means that we will freeze
            # runtime asserts, and then the moment the backwards lowering
            # process attempts to add a new deferred runtime assert, we will
            # fail. Let's say you remove that assert. Now when we get here,
            # we need to make sure we actually emit these asserts (because we
            # can't emit them in forwards, we already compiled it). So we
            # have to do something here. But we don't want to reemit ALL
            # deferred runtime asserts, we only want to emit the NEW ones.
            # Therefore needing some sort of stratification in the ShapeEnv.
            # This is all doable, it just hasn't been done yet.
            shape_env = V.graph.sizevars.shape_env

            for i0 in new_unbacked_defs:
                ras = self.ras_by_symbol.pop(i0, [])
                # NB: size-like not needed, we won't retrace
                vr = shape_env.var_to_range[i0]
                if not shape_env._default_unspecified_value_range().issubset(vr):

                    def is_convertible(s):
                        if s in (int_oo, -int_oo):
                            return False
                        try:
                            int(s)
                            return True
                        except TypeError:
                            return False

                    if is_convertible(vr.lower):
                        self.register_buffer(
                            ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
                            set_name=True,
                        )
                    if is_convertible(vr.upper):
                        self.register_buffer(
                            ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
                            set_name=True,
                        )

                for ra in ras:
                    fvs = free_unbacked_symbols(ra.expr)
                    missing = fvs - self.bound_unbacked_symbols
                    if missing:
                        i1 = sorted(missing, key=lambda x: str(x))[0]
                        self.ras_by_symbol.setdefault(i1, []).append(ra)
                    else:
                        self.register_buffer(
                            ir.AssertScalar(ra.expr, f"{ra.expr}"), set_name=True
                        )

            self.bound_unbacked_symbols |= new_unbacked_defs

            unbacked_bindings = resolve_unbacked_bindings(
                V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
            )
            # When we do lowering, it is possible we reallocate unbacked SymInts.
            # So we need to line up the unbacked SymInts when performing the test
            # here
            #
            # In principle, we could permit lowering to introduce MORE unbacked
            # SymInts: as long as all the old unbacked ones are accounted for,
            # it's fine for inductor to introduce extra calls to item()/unbacked()
            # whatever. This actually happens in practice when an unbacked SymInt
            # gets memoized away; naively, when Inductor reprocesses a kernel, it
            # doesn't know that the memo still applies, and ends up allocating a
            # new symbol. However, this is generally a bad thing: we may still
            # end up needing to test equalities on the symbols, and a fresh
            # symbol is likely to hit lots of GuardOnDataDependent errors that
            # we already know facts for.
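            # Translate the FX node's unbacked bindings through any renamings
            # recorded by the ShapeEnv, so the subset check below compares
            # like-for-like symbols.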
            renamed_unbacked_bindings = {
                V.fake_mode.shape_env.unbacked_renamings.get(s, s)
                for s in unbacked_bindings.keys()
            }
            assert new_unbacked_defs >= renamed_unbacked_bindings, (
                f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
                f"fx node is: {n.format_node()}\n"
                f"new buffers are:\n\n{format_buffers()}"
            )

        return result

    def validate_can_generate_cpp_wrapper(self):
        if config.disable_cpp_codegen:
            raise CppWrapperCodeGenError("C++ codegen is disabled")

        if sys.platform not in ["linux", "darwin"]:
            raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")

        for value in self.graph_inputs.values():
            dtype = None
            if isinstance(value, TensorBox):
                dtype = value.get_dtype()
            elif isinstance(
                value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
            ):
                dtype = may_get_constant_buffer_dtype(value)

            if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
                raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")

    def init_wrapper_code(self):
        self.cuda = "cuda" in self.device_types
        if self.cpp_wrapper:
            self.validate_can_generate_cpp_wrapper()

        device_types = self.device_types.copy()
        device_types.discard("cpu")
        device_types.discard("meta")
        # TODO(Eikan): Only support mixing cpu and other device now.
        assert len(device_types) <= 1, "Does not support mixing {}".format(
            "+".join(device_types)
        )
        only_cpu = len(device_types) == 0
        device_type = "cpu" if only_cpu else device_types.pop()

        self.device_ops = get_device_op_overrides(device_type)
        wrapper_code_gen_cls = get_wrapper_codegen_for_device(
            device_type, self.cpp_wrapper
        )
        assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
        self.wrapper_code = wrapper_code_gen_cls()

        if self.const_module:
            # If we have a const module, we can reuse its kernels.
            # This avoids duplication and saves recompilation time (if Triton).
            self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
            self.wrapper_code.src_to_kernel = (
                self.const_module.wrapper_code.src_to_kernel
            )

    def codegen_with_cpp_wrapper(self):
        """
        For CPU, the cpp wrapper codegen is done in one pass.
        For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
        wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
        generate cpp wrapper code and compile it to a dynamic library in the second pass.
        """
        if "cuda" in self.device_types:
            # first pass
            self.cpp_wrapper = False
            # Although triton.store_cubin was set in compile_fx, the backward pass didn't pick
            # that up. In theory it should work by only setting triton.store_cubin to True here,
            # but that will cause a problem when use_runtime_constant_folding is set.
            with config.patch({"triton.store_cubin": True}):
                compiled = self.compile_to_module().call

            def materialize(x):
                if isinstance(x, (torch.SymInt, torch.SymFloat)):
                    # Need concrete value to run dynamic shapes and tune the result
                    return x.node.hint
                elif isinstance(x, FakeTensor):
                    return defake(x)
                else:
                    assert isinstance(
                        x, torch.Tensor
                    ), "Unknown type when creating real inputs" + str(type(x))
                    return x

            tracing_context = torch._guards.TracingContext.try_get()
            if tracing_context is not None and not isinstance(
                V.real_inputs, NullHandler
            ):
                if tracing_context.output_strides:
                    tracing_context.output_strides.clear()

                params_flat = [
                    param
                    for param in tracing_context.params_flat  # type: ignore[union-attr]
                    if param is not None
                ]
                real_inputs = [
                    materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
                ]
            else:
                # In the backward pass, V.real_inputs is not set.
                # Generating random inputs based on self.example_inputs sometimes can be problematic,
                # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
                real_inputs = [
                    materialize(x)
                    for x in (
                        self.example_inputs
                        if isinstance(V.real_inputs, NullHandler)
                        else V.real_inputs
                    )
                ]

            if self.mutated_inputs:
                from .compile_fx import clone_preserve_strides

                mutated_input_idxs = [
                    idx
                    for idx, name in enumerate(self.graph_inputs)
                    if name in self.mutated_inputs
                    and isinstance(real_inputs[idx], torch.Tensor)
                ]
                for idx in mutated_input_idxs:
                    # clone mutated Tensor inputs to avoid mutating them in
                    # the first pass of the CPP wrapper-based compilation, as
                    # this will lead to a side effect on the example inputs:
                    # e.g. if torch.compile(f)(x) is called on an input-mutating
                    # f, the inputs x will be mutated twice in the process:
                    # once here, and again when running the compiled model;
                    # this will also lead to a numerically incorrect output
                    real_inputs[idx] = clone_preserve_strides(real_inputs[idx])

            with torch.utils._python_dispatch._disable_current_modes():
                compiled(real_inputs)
            del real_inputs

            # second pass
            # TODO: reuse self.scheduler from the first pass to speed up the second pass
            self.cpp_wrapper = True
            self.removed_buffers.clear()
            self.inplaced_to_remove.clear()
            V.graph.sizevars.precomputed_replacements.clear()
            V.graph.sizevars.inv_precomputed_replacements.clear()
            return self.codegen()
        else:
            # cpu
            return self.codegen()

    def codegen(self):
        from .scheduler import Scheduler

        self.init_wrapper_code()

        self.scheduler = Scheduler(self.buffers)
        V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)

        self.wrapper_code.push_codegened_graph(self)
        self.scheduler.codegen()
        result = self.wrapper_code.generate(self.is_inference)
        self.wrapper_code.pop_codegened_graph()
        return result

    def codegen_subgraph(self, parent_graph):
        """
        This is a more compact version of the `codegen()` above
        where we codegen this graph as a subgraph of some parent
        graph. The parent graph is passed as an argument: the
        intention is to inline the codegen of the subgraph into
        the parent graph's wrapper code (including the generated
        kernels). The wrapper code is not finalized (via `.generate()`
        call), as this will be done in the parent graph's `codegen()`.
        """
        from .scheduler import Scheduler

        self.wrapper_code = parent_graph.wrapper_code
        self.device_ops = parent_graph.device_ops
        self.cpp_wrapper = parent_graph.cpp_wrapper

        self.scheduler = Scheduler(self.buffers)
        self.scheduler.codegen()

    def count_bytes(self):
        total_bytes = 0
        node_counts = []
        node_runtimes = []
        for node in self.scheduler.nodes:
            num_bytes = node.get_read_write_buffers_sizes()
            total_bytes += num_bytes
            node_counts.append((node, num_bytes // 4))
            node_runtimes.append((node, node.get_estimated_runtime()))
        return total_bytes, node_counts, node_runtimes

    @dynamo_timed(phase_name="code_gen", fwd_only=False)
    def compile_to_module(self):
        from .codecache import PyCodeCache

        code, linemap = (
            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
        )

        output_code_log.debug("Output code: \n%s", code)
        try:
            linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
            key, path = PyCodeCache.write(code)
        except Exception:
            trace_structured(
                "inductor_output_code",
                # Just omit the filename, I still want the code though!
                payload_fn=lambda: code,
            )
            raise
        else:
            trace_structured(
                "inductor_output_code",
                lambda: {"filename": path},
                payload_fn=lambda: code,
            )

        mod = PyCodeCache.load_by_key_path(
            key,
            path,
            linemap=linemap,
            attrs={**self.constants, **self.torchbind_constants},
        )
        self.cache_key = key
        self.cache_path = path
        self.cache_linemap = linemap

        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
        # TODO. Revisit this once the logging API is more mature
        assert mod.__file__ is not None

        log_module_code(mod.__file__)
        log.debug("Output code written to: %s", mod.__file__)
        output_code_log.info("Output code written to: %s", mod.__file__)
        if config.benchmark_kernel:
            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
        V.debug.output_code(mod.__file__)
        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
        return mod

    def compile_to_fn(self):
        if self.aot_mode:
            from .codecache import AotCodeCompiler

            assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
            code, linemap = self.codegen_with_cpp_wrapper()
            output_code_log.debug("Output code: \n%s", code)

            serialized_extern_kernel_nodes = None
            if (
                config.is_fbcode()
                and self.extern_kernel_nodes
                and self.extern_node_serializer
            ):
                serialized_extern_kernel_nodes = self.extern_node_serializer(
                    self.extern_kernel_nodes
                )
                output_code_log.debug(
                    "Serialized Extern Kernel Nodes: \n%s",
                    serialized_extern_kernel_nodes,
                )

            # Directly return the file path with the compiled code
            return AotCodeCompiler.compile(
                self, code, serialized_extern_kernel_nodes, cuda=self.cuda
            )
        else:
            return self.compile_to_module().call

    def get_output_names(self):
        return [
            node.get_name()
            for node in self.graph_outputs
            if not isinstance(node, ir.NoneAsConstantBuffer)
            and not isinstance(node, ir.ShapeAsConstantBuffer)
        ]

    def is_unspec_arg(self, name: str):
        # dynamo wraps unspec variable as 0d CPU tensor,
        # need to convert to scalar during codegen (triton only)
        return (
            name in self.graph_inputs.keys()
            and self.graph_inputs[name].get_numel() == 1
            and self.graph_inputs[name].get_device().type == "cpu"
        )