1# mypy: allow-untyped-defs 2import itertools 3import logging 4import typing 5from collections import Counter 6from typing import Any, Dict, List, Set, Union 7 8import torch 9import torch._guards 10import torch.utils._pytree as pytree 11from torch._inductor.constant_folding import ConstantFolder 12from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict 13from torch.fx.experimental.symbolic_shapes import statically_known_true 14from torch.fx.passes.graph_transform_observer import GraphTransformObserver 15from torch.multiprocessing.reductions import StorageWeakRef 16 17from ...utils._ordered_set import OrderedSet 18from .. import config 19from ..pattern_matcher import ( 20 CallFunction, 21 init_once_fakemode, 22 KeywordArg, 23 Match, 24 MULTIPLE, 25 PatternMatcherPass, 26 register_graph_pattern, 27 stable_topological_sort, 28) 29from .replace_random import replace_random_passes 30 31 32log = logging.getLogger(__name__) 33patterns = PatternMatcherPass() 34aten = torch.ops.aten 35prims = torch.ops.prims 36 37pass_patterns = [ 38 patterns, 39 PatternMatcherPass(), 40] 41 42 43@init_once_fakemode 44def lazy_init(): 45 from .fuse_attention import _sfdp_init 46 from .misc_patterns import _misc_patterns_init 47 from .pad_mm import _pad_mm_init 48 49 _pad_mm_init() 50 _sfdp_init() 51 _misc_patterns_init() 52 53 54def remove_no_ops( 55 gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node] 56): 57 with torch.utils._python_dispatch._disable_current_modes(): 58 "Removes no-ops: (+ 0, - 0, * 1, / 1)" 59 graph = gm.graph 60 61 def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")): 62 if any(not isinstance(t, torch.Tensor) for t in (t1, t2)): 63 return False 64 for field in fields: 65 if getattr(t1, field) != getattr(t2, field): 66 return False 67 return True 68 69 def replace_no_op(node, replace_input_index): 70 replacement = node.args[replace_input_index] 71 72 # https://github.com/pytorch/pytorch/issues/86128 causes 73 # non-Tensor inputs even for ops with only Tensor inputs. 74 # TODO - decompose/type promote to avoid this 75 if not all(isinstance(arg, torch.fx.Node) for arg in node.args): 76 return 77 78 if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]): 79 if fake_tensors_eq( 80 node.meta["val"], 81 replacement.meta["val"], 82 ("shape", "device"), 83 ): 84 with graph.inserting_after(node): 85 replacement = graph.call_function( 86 torch.ops.prims.convert_element_type.default, 87 args=(replacement, node.meta["val"].dtype), 88 ) 89 else: 90 return 91 92 node.replace_all_uses_with(replacement) 93 replacement.meta.update(node.meta) 94 graph.erase_node(node) 95 96 for node in graph.find_nodes(op="call_function", target=aten.add.Tensor): 97 # TODO handle Tensor-Scalar adds, it's a different schema 98 if len(node.args) == 2: 99 if ( 100 not any(e in zeros for e in node.args) 101 or node.kwargs.get("alpha", 1) != 1 102 ): 103 continue 104 105 replace_index = 1 if node.args[0] in zeros else 0 106 replace_no_op(node, replace_index) 107 108 for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor): 109 if len(node.args) == 2: 110 if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1: 111 continue 112 113 replace_no_op(node, 0) 114 115 for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor): 116 if len(node.args) == 2: 117 if not any(e in ones for e in node.args): 118 continue 119 120 replace_input_index = 1 if node.args[0] in ones else 0 121 replace_no_op(node, replace_input_index) 122 123 for node in graph.find_nodes(op="call_function", target=aten.div.Tensor): 124 if len(node.args) == 2 and node.args[1] in ones: 125 replace_no_op(node, 0) 126 127 # meta tensors returned from the graph have no data and can be replaced with empty_strided 128 for output_node in graph.find_nodes(op="output"): 129 had_meta_return = False 130 131 def visit(n): 132 nonlocal had_meta_return 133 val = n.meta.get("val") 134 if isinstance(val, torch.Tensor) and val.device.type == "meta": 135 with graph.inserting_before(output_node): 136 n.replace_all_uses_with( 137 graph.call_function( 138 torch.ops.aten.empty_strided.default, 139 args=(val.size(), val.stride()), 140 kwargs={"dtype": val.dtype, "device": val.device}, 141 ) 142 ) 143 had_meta_return = True 144 145 torch.fx.map_arg(output_node.args, visit) 146 if had_meta_return: 147 graph.eliminate_dead_code() 148 149 150def remove_redundant_views(gm: torch.fx.GraphModule): 151 """ 152 Removes redundant views by reusing existing ones. 153 """ 154 with torch.utils._python_dispatch._disable_current_modes(): 155 # A dictionary mapping a tensor to all aliased views. 156 views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {} 157 graph = gm.graph 158 159 for node in graph.find_nodes( 160 op="call_function", target=torch.ops.aten.view.dtype 161 ): 162 src = node.args[0] 163 to_type = node.args[1] 164 existing_views = views.get(src) 165 is_needed = True 166 167 if existing_views: 168 # Replace the view with the an existing view if available. 169 alias = existing_views.get(to_type) 170 if alias: 171 is_needed = False 172 node.replace_all_uses_with(alias) 173 alias.meta.update(node.meta) 174 graph.erase_node(node) 175 else: 176 from_type = src.meta["val"].dtype 177 existing_views = {from_type: src} 178 views[src] = existing_views 179 180 if is_needed: 181 # Save the new alias but do not replace existing one. 182 existing_views.setdefault(to_type, node) 183 views[node] = existing_views 184 185 # Clean up unused views. 186 while True: 187 unused_views = [alias for alias in views if not alias.users] 188 if len(unused_views) == 0: 189 break 190 for unused in unused_views: 191 views.pop(unused) 192 graph.erase_node(unused) 193 194 195class UniformValueConstantFolder(ConstantFolder): 196 """ 197 Runs constant folding and replaces tensors that have a unifrom value 198 with a tensor constructor call: aten.full([shape], value, ...) 199 """ 200 201 def __init__(self, gm, skip_constructors=False) -> None: 202 super().__init__(gm, skip_constructors) 203 self.node_storages_ptrs: Dict[torch.fx.Node, int] = {} 204 self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {} 205 # we may constant fold a tensor which in the graph has a sym size 206 # see: [constant folding refining of symints] 207 self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {} 208 209 # initialize symint -> node mapping so that we can 210 # use symint nodes in full constructors 211 self.symint_nodes = _SymHashingDict() 212 for n in self.module.graph.nodes: 213 if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): 214 self.symint_nodes[n.meta["val"]] = n 215 216 # reference from torch/_funtorch/partitioners.py:get_default_op_list 217 self.view_op_packets = [ 218 aten.squeeze, 219 aten.unsqueeze, 220 aten.alias, 221 aten.view, 222 aten.slice, 223 aten.t, 224 prims.broadcast_in_dim, 225 aten.expand, 226 aten.as_strided, 227 aten.permute, 228 ] 229 230 self.indexing_op_packets = { 231 aten.slice, 232 } 233 234 def _support_dynamic_shape(self): 235 return True 236 237 def insertable_tensor_check(self, t: torch.Tensor) -> bool: 238 return True 239 240 def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: 241 self.node_replacements[node] = tensor.flatten()[0].item() 242 self.node_replacements_shapes[node] = node.meta["val"].shape 243 self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage()) 244 245 def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None: 246 for n in self.module.graph.find_nodes(op="placeholder"): 247 if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): 248 env[n] = n.meta["val"] 249 else: 250 env[n] = self.unknown_value 251 252 def _deduce_value(self, node: torch.fx.Node): 253 # deduce value for full-like nodes 254 # 1. for constructors, substitute value is a tensor of size [1] 255 # 2. for view ops/indexing, substitute value is the same as the input 256 # 3. for pointwise ops, run node to get the substitute value 257 # 4. deal with some special ops 258 # otherwise, stop deduce value and return unknown value 259 260 # TODO: cat, more indexing 261 # TODO - do on cpu to avoid syncs 262 263 # single-elem attrs 264 if node.op == "get_attr" or ( 265 node.op == "call_function" 266 and node.target == torch.ops.aten.lift_fresh_copy.default 267 ): 268 out = super(ConstantFolder, self).run_node(node) 269 if isinstance(out, torch.Tensor) and out.numel() == 1: 270 return out 271 272 # handle device_put op 273 if node.target == prims.device_put.default: 274 return super(ConstantFolder, self).run_node(node) 275 276 # constructors ops 277 if ( 278 node.op == "call_function" 279 and node.target == aten.full.default 280 and len(node.args) == 2 281 ): 282 args, kwargs = self.fetch_args_kwargs_from_env(node) 283 new_args = [[1], args[1]] 284 return aten.full.default(*new_args, **node.kwargs) 285 286 # handle before view ops because this changes value 287 if node.target == aten.view.dtype: 288 return super(ConstantFolder, self).run_node(node) 289 290 # view ops, return input tensor, the first argument 291 if hasattr(node.target, "overloadpacket") and ( 292 node.target.overloadpacket in self.view_op_packets 293 or node.target.overloadpacket in self.indexing_op_packets 294 ): 295 assert isinstance(node.args[0], torch.fx.Node) 296 return self.env[node.args[0]] 297 298 # we don't want to return unknown value for symints so that we can 299 # still constant fold through their use in constructors or views 300 # if we see them in a pointwise node (e.g., tensor * symint) 301 # we will bail 302 if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt): 303 return node.meta["val"] 304 305 # pointwise ops 306 if isinstance(node.target, torch._ops.OpOverload) and ( 307 torch.Tag.pointwise in node.target.tags 308 or node.target is torch.ops.aten.scalar_tensor.default 309 ): 310 args, kwargs = self.fetch_args_kwargs_from_env(node) 311 flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) 312 313 if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs): 314 return self.unknown_value 315 316 # we run the ops with dim 1, so remove memory_format to avoid error 317 kwargs = dict(kwargs) 318 kwargs.pop("memory_format", None) 319 320 return node.target(*args, **kwargs) 321 322 return self.unknown_value 323 324 325def constant_fold_uniform_value(gm: torch.fx.GraphModule): 326 with torch.utils._python_dispatch._disable_current_modes(): 327 "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops." 328 aten = torch.ops.aten 329 330 # Constant folding can leak memory, especially with repeated compilation, so we are only going to 331 # remove constants which can be replaced with a constructor. 332 cf = UniformValueConstantFolder(gm) 333 cf.run() 334 335 node_replacements = cf.node_replacements 336 337 # note: [constant folding refining of symints] 338 # constant folding will partially evaluate a graph such that values which have dependencies which 339 # are entirely known at compile time may also become compile time constants. in some cases, 340 # this will include symints which we had not yet previously deduced are guaranteed a 341 # constant value and is then deduced in constant folding. an example is: 342 # unbacked_symint_eq_11 = torch.full((), 11).item() 343 # torch.full((unbacked_symint_eq_11,), 0) 344 node_replacements_shapes = cf.node_replacements_shapes 345 346 graph = gm.graph 347 348 zeros = set() 349 ones = set() 350 351 # Got failures in `test_is_set_to_cuda` if we change aliasing on constants, 352 # so just constant-ify if a Tensor is unaliased 353 constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter() 354 355 for node in cf.node_replacements: 356 constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1 357 358 for node, value in node_replacements.items(): 359 # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now 360 # hasn't shown up to be important yet 361 if "val" not in node.meta: 362 # This can only happen in AOTI 363 continue 364 365 fake_tensor = node.meta["val"] 366 if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format): 367 continue 368 369 # TODO - not sure about lossy uint->python value->uint conversions 370 if fake_tensor.dtype in ( 371 torch.uint8, 372 torch.uint16, 373 torch.uint32, 374 torch.uint64, 375 ): 376 continue 377 378 if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1: 379 continue 380 381 with graph.inserting_after(node): 382 # the conversion from tensor and back to value can be lossy, just use the original full ctor value 383 if ( 384 node.op == "call_function" 385 and node.target == aten.full.default 386 and len(node.args) == 2 387 ): 388 value = node.args[1] 389 390 # refines symints, see [constant folding refining of symints] above 391 for runtime_size, compile_time_size in zip( 392 node_replacements_shapes[node], fake_tensor.shape 393 ): 394 torch._check(runtime_size == compile_time_size) 395 396 # replace SymInt as Node before creating a new full node 397 # e.g. (1, s0) -> (1, arg0_1) 398 node_shape = node_replacements_shapes[node] 399 if not all( 400 not isinstance(s, torch.SymInt) or s in cf.symint_nodes 401 for s in node_shape 402 ): 403 continue 404 405 shapes = [ 406 cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s 407 for s in node_replacements_shapes[node] 408 ] 409 410 # zeros and ones just get traced into full, so we insert those 411 new_node = graph.call_function( 412 aten.full.default, 413 args=(shapes, value), 414 kwargs={ 415 "dtype": fake_tensor.dtype, 416 "layout": torch.strided, 417 "device": fake_tensor.device, 418 "pin_memory": False, 419 }, 420 ) 421 422 new_node.meta.update(node.meta) 423 node.replace_all_uses_with(new_node) 424 graph.erase_node(node) 425 426 if value == 0: 427 zeros.add(new_node) 428 elif value == 1: 429 ones.add(new_node) 430 431 remove_no_ops(gm, zeros, ones) 432 remove_redundant_views(gm) 433 434 435def joint_graph_passes(graph: torch.fx.GraphModule): 436 """ 437 Run FX transformations on the joint forwards+backwards graph. 438 """ 439 lazy_init() 440 count = 0 441 if config.joint_custom_pre_pass is not None: 442 with GraphTransformObserver( 443 graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform 444 ): 445 config.joint_custom_pre_pass(graph.graph) 446 count += 1 447 448 from .post_grad import remove_noop_ops 449 450 remove_noop_ops(graph.graph) 451 452 if config.joint_graph_constant_folding: 453 with GraphTransformObserver( 454 graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform 455 ): 456 constant_fold_uniform_value(graph) 457 458 if config.pattern_matcher: 459 for patterns in pass_patterns: 460 count += patterns.apply(graph.graph) # type: ignore[arg-type] 461 462 if not config.fallback_random: 463 count += replace_random_passes(graph) 464 465 if config.joint_custom_post_pass is not None: 466 with GraphTransformObserver( 467 graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform 468 ): 469 config.joint_custom_post_pass(graph.graph) 470 count += 1 471 472 if count: 473 stable_topological_sort(graph.graph) 474 graph.graph.lint() 475 graph.recompile() 476 return graph 477 478 479@register_graph_pattern( 480 CallFunction( 481 torch.ops.prims.iota.default, 482 KeywordArg("length"), 483 start=KeywordArg("start"), 484 step=KeywordArg("step"), 485 dtype=KeywordArg("dtype"), 486 device=KeywordArg("device"), 487 requires_grad=KeywordArg("requires_grad"), 488 ), 489 pass_dict=patterns, 490) 491def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad): 492 """ 493 Eager supports: 494 495 aten.index(cuda_tensor, torch.arange(..., device="cpu")) 496 497 But this results in an implicit host-device-copy and breaks cudagraphs. 498 Rewrite the arange to use CUDA. 499 """ 500 (node,) = match.nodes 501 user_devices: OrderedSet[torch.device] = OrderedSet() 502 for user in node.users: 503 if ( 504 user.op == "call_function" 505 and user.target in (aten.index.Tensor, aten.index_put.default) 506 and hasattr(user.meta.get("val"), "device") 507 ): 508 user_devices.add(user.meta["val"].device) # type: ignore[union-attr] 509 else: 510 return # bail out 511 512 if len(user_devices) == 1 and "val" in node.meta: 513 (user_device,) = user_devices 514 if device.type != user_device.type: 515 repl = match.graph.call_function( 516 torch.ops.prims.iota.default, 517 (length,), 518 { 519 "start": start, 520 "step": step, 521 "dtype": dtype, 522 "device": user_device, 523 "requires_grad": requires_grad, 524 }, 525 ) 526 repl.meta.update(node.meta) 527 repl.meta["val"] = repl.meta["val"].to(user_device) 528 node.replace_all_uses_with(repl) 529 match.erase_nodes() 530 531 532@register_graph_pattern( 533 CallFunction( 534 torch.ops.prims.convert_element_type.default, 535 CallFunction( 536 torch.ops.prims.convert_element_type.default, 537 KeywordArg("arg"), 538 KeywordArg("dtype1"), 539 ), 540 KeywordArg("dtype2"), 541 ), 542 pass_dict=patterns, 543) 544def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype): 545 """Remove chain of dtype conversions often created by AMP""" 546 graph = match.graph 547 node = match.output_node() 548 allowed = {torch.float16, torch.bfloat16, torch.float32, torch.float64} 549 if dtype1 in allowed and dtype2 in allowed: 550 repl = graph.call_function( 551 torch.ops.prims.convert_element_type.default, (arg, dtype2) 552 ) 553 repl.meta.update(node.meta) 554 node.replace_all_uses_with(repl) 555 match.erase_nodes() 556 557 558@register_graph_pattern( 559 CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")), 560 pass_dict=patterns, 561) 562def pointless_view(match: Match, arg, size): 563 """Remove no-op view""" 564 node = match.output_node() 565 arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] 566 if size == arg_size: 567 node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] 568 match.erase_nodes() 569 570 571# When softmax is used with temperature or other scaling, we get the pattern 572# 573# scale(x) - scale(x).amax(dim, keepdim=True) 574# 575# which is expected to be at most zero, but we may end up with numerical 576# discrepancies # between the recomputed values of scale(x) inside and out 577# of the reduction, # depending on compiler optimizations, e.g. use of fma 578# instructions. 579# 580# Here we replace it with the mathematically equivalent, 581# 582# scale(x - x.amax(dim, keepdim=True)) 583# 584# which is more stable as we only compute the scaling once. 585# 586# NOTE: This pattern must come after fused attention matching! 587 588 589def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False): 590 # Allow matching inp * other and other * input 591 if reverse: 592 scaled = CallFunction( 593 linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE 594 ) 595 else: 596 scaled = CallFunction( 597 linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE 598 ) 599 if to_dtype: 600 scaled = CallFunction( 601 prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE 602 ) 603 amax = CallFunction( 604 aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim") 605 ) 606 return CallFunction(aten.sub.Tensor, scaled, amax) 607 608 609def _other_is_broadcasted_in_dim(match): 610 # Check that the scaling factor is constant across the reduction dim, 611 # so scaling doesn't change which index corresponds to the maximum value 612 other = match.kwargs["other"] 613 if isinstance(other, (int, float)): 614 return True 615 616 inp = match.kwargs["inp"] 617 if not all(isinstance(x, torch.fx.Node) for x in (inp, other)): 618 return False 619 620 inp_example = inp.meta["val"] 621 other_example = other.meta["val"] 622 if isinstance(other_example, (torch.SymInt, torch.SymFloat)): 623 return True 624 625 if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)): 626 return False 627 628 inp_ndim = inp_example.ndim 629 other_shape = other_example.shape 630 if inp_ndim < len(other_shape): 631 return False 632 633 # Pad other_shape to the same ndim as inp 634 other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape) 635 636 dim = match.kwargs["dim"] 637 if isinstance(dim, int): 638 dim = (dim,) 639 640 return all(statically_known_true(other_shape[d] == 1) for d in dim) 641 642 643def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): 644 def repl(inp, other): 645 if dtype is not None: 646 inp = inp.to(dtype) 647 648 sign: Union[int, float, torch.Tensor] 649 if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): 650 sign = 1 if other >= 0 else -1 651 else: 652 one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) 653 sign = torch.where(other >= 0, one, -one) 654 655 inp = inp * sign 656 max_ = torch.amax(inp, dim=dim, keepdim=keepdim) 657 return (inp - max_) * (sign * other) 658 659 match.replace_by_example(repl, [inp, other]) 660 661 662for reverse, to_dtype in itertools.product((False, True), repeat=2): 663 register_graph_pattern( 664 _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype), 665 pass_dict=pass_patterns[1], 666 extra_check=_other_is_broadcasted_in_dim, 667 )(mul_softmax_pattern) 668 669 670def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): 671 def repl(inp, other): 672 if dtype is not None: 673 inp = inp.to(dtype) 674 675 sign: Union[int, float, torch.Tensor] 676 if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): 677 sign = 1 if other >= 0 else -1 678 else: 679 one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) 680 sign = torch.where(other >= 0, one, -one) 681 682 inp = inp * sign 683 max_ = torch.amax(inp, dim=dim, keepdim=keepdim) 684 return (inp - max_) / (sign * other) 685 686 match.replace_by_example(repl, [inp, other]) 687 688 689for to_dtype in (False, True): 690 register_graph_pattern( 691 _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype), 692 pass_dict=pass_patterns[1], 693 extra_check=_other_is_broadcasted_in_dim, 694 )(div_softmax_pattern) 695