# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import itertools
import logging
import operator
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union

import torch
import torch._inductor as inductor
import torch.utils._pytree as pytree
from torch import fx
from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor import comms
from torch._inductor.virtualized import ops
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
from torch._utils_internal import upload_graph
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
from torch.fx.passes.graph_transform_observer import GraphTransformObserver

from .. import config, ir, pattern_matcher
from ..codegen.common import BackendFeature, has_backend_feature
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
from ..lowering import lowerings as L
from ..pattern_matcher import (
    _return_true,
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    filter_nodes,
    get_arg_value,
    get_mutation_region_id,
    Ignored,
    init_once_fakemode,
    KeywordArg,
    ListOf,
    Match,
    MULTIPLE,
    PatternMatcherPass,
    register_graph_pattern,
    stable_topological_sort,
)
from ..utils import decode_device, get_gpu_type, is_pointwise_use
from ..virtualized import V
from .b2b_gemm import B2B_GEMM_PASS
from .ddp_fusion import fuse_ddp_communication
from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS
from .micro_pipeline_tp import micro_pipeline_tp_pass
from .pre_grad import is_same_dict, save_inductor_dict
from .reinplace import reinplace_inplaceable_ops
from .split_cat import POST_GRAD_PATTERNS


if TYPE_CHECKING:
    from sympy import Expr


log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
    PatternMatcherPass(),
]
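
# Illustrative (hypothetical) usage of the ordering above: a replacement
# registered with the helper defined further below as
#
#     @register_lowering_pattern(some_pattern, pass_number=2)
#     def replacement(match, ...):
#         ...
#
# only runs after every pattern registered into pass_patterns[0] and
# pass_patterns[1] has been applied. `some_pattern` and `replacement` are
# placeholder names, not part of this file.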


def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
    """
    Passes that run after grad. This is called once on the forwards
    graph and once on the backwards graph.

    The IR here has been normalized and functionalized.
    """
    if config.dce:
        # has some issues with mutation in inference mode
        gm.graph.eliminate_dead_code()

    if is_inference and config.reorder_for_locality:
        reorder_for_locality(gm.graph)

    fake_tensor_updater = FakeTensorUpdater(gm.graph)

    if config.post_grad_custom_pre_pass is not None:
        with GraphTransformObserver(
            gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
        ):
            config.post_grad_custom_pre_pass(gm.graph)

    if config.pattern_matcher:
        lazy_init()
        optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
        group_batch_fusion_passes(gm.graph, pre_grad=False)
        remove_noop_ops(gm.graph)
        for patterns in pass_patterns:
            patterns.apply(gm.graph)  # type: ignore[arg-type]
        for pass_name in config.post_grad_fusion_options:
            # skip all patterns for group batch fusions
            if pass_name in POST_GRAD_FUSIONS:
                continue
            pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
            inductor_before_change = save_inductor_dict(
                [pattern_matcher_pass.pass_name]
            )
            pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
            if not is_same_dict(counters["inductor"], inductor_before_change):
                optimus_scuba_log[
                    f"{pattern_matcher_pass.pass_name}_post_grad"
                ] = upload_graph(gm.graph)
        if config.b2b_gemm_pass:
            B2B_GEMM_PASS.apply(gm.graph)  # type: ignore[arg-type]

    if config._micro_pipeline_tp:
        micro_pipeline_tp_pass(gm.graph)

    if config._fuse_ddp_communication:
        fuse_ddp_communication(
            gm.graph,
            config._fuse_ddp_communication_passes,
            config._fuse_ddp_bucket_size,
        )

    if config.post_grad_custom_post_pass is not None:
        with GraphTransformObserver(
            gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
        ):
            config.post_grad_custom_post_pass(gm.graph)

    stable_topological_sort(gm.graph)

    move_constructors_to_gpu(gm.graph)

    fake_tensor_updater.incremental_update()

    # Keep these last, since they introduce mutation. Look at
    # ./fx_passes/README.md for a discussion of mutation invariants.
    reinplace_inplaceable_ops(gm.graph)
    decompose_auto_functionalized(gm.graph)

    comms.reinplace_fsdp_all_gather(gm.graph)

    gm.recompile()
    optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)
    gm.graph.lint()


@init_once_fakemode
def lazy_init():
    if torch._C._has_mkldnn:
        from . import decompose_mem_bound_mm  # noqa: F401
        from .mkldnn_fusion import _mkldnn_fusion_init

        _mkldnn_fusion_init()


def reorder_for_locality(graph: torch.fx.Graph):
    def visit(other_node):
        if (
            other_node.op == "call_function"
            and other_node.target != operator.getitem
            and all((n in seen_nodes) for n in other_node.users)
            and get_mutation_region_id(graph, node)
            == get_mutation_region_id(graph, other_node)
        ):
            # move node's producers right before it
            node.prepend(other_node)

    seen_nodes = set()

    # only reorder nodes before the first copy_ in the graph.
    # copy_ will appear at the end of functionalized graphs when there is mutation on inputs,
    # and this reordering doesn't work well with mutation
    first_copy = next(
        iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)),
        None,
    )
    past_mutating_epilogue = True if first_copy is None else False

    for node in reversed(graph.nodes):
        seen_nodes.add(node)
        if not past_mutating_epilogue:
            past_mutating_epilogue = node is first_copy
            continue

        torch.fx.map_arg((node.args, node.kwargs), visit)


def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
    """
    Register an aten to inductor IR replacement pattern
    """
    return pattern_matcher.register_lowering_pattern(
        pattern, extra_check, pass_dict=pass_patterns[pass_number]
    )


################################################################################
# Actual patterns below this point.
# Priority of patterns is:
# - later output nodes first
# - order patterns are defined in
################################################################################


def is_valid_mm_plus_mm(match: Match):
    *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape
    *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape
    if k1 != k2:
        return False

    *b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape
    *b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape
    if k3 != k4:
        return False

    if m1 != m2 or n1 != n2:
        return False

    return True


def scatter_upon_const_tensor_extra_check(m):
    if not config.optimize_scatter_upon_const_tensor:
        return False
    full_shape = m.kwargs["shape"]
    selector = m.kwargs["selector"]
    dim = m.kwargs["dim"]
    if dim < 0:
        dim += len(full_shape)

    selector_ft = selector.meta["val"]
    assert selector_ft.dim() == len(full_shape)

    for idx, select_sz, full_sz in zip(
        itertools.count(), selector_ft.shape, full_shape
    ):
        if idx == dim:
            continue

        # TODO: the pattern can be updated to support the case where the index
        # tensor is shorter. But that will need a more complex condition expression,
        # especially for multi-dimensional tensors.
        # Skip it for now.
        if isinstance(full_sz, fx.Node):
            full_sz = full_sz.meta["val"]
        if select_sz < full_sz:
            return False

    # Actually we could support a small size larger than 1. It would be a bit
    # tedious. E.g., we load all the index values (not many) and compare
    # them with the position in the tensor to decide what value to return.
    return selector_ft.size(dim) == 1
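

# Illustrative sketch of the rewrite implemented below (hypothetical shapes and
# names, eager-mode equivalent only): a matched graph such as
#
#     torch.full([m, n], background_val).scatter(1, sel, val)
#
# where `sel` has size 1 along dim=1, is lowered to the pointwise equivalent of
#
#     torch.where(torch.arange(n) == sel, val, background_val)
#
# broadcast to the full shape, so neither the materialized `full` tensor nor a
# scatter kernel is needed.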
@register_lowering_pattern(
    CallFunction(
        aten.scatter.value,
        CallFunction(
            aten.full,
            KeywordArg("shape"),
            KeywordArg("background_val"),
            dtype=KeywordArg("dtype"),
        ),
        KeywordArg("dim"),
        KeywordArg("selector"),
        KeywordArg("val"),  # scalar value
    ),
    extra_check=scatter_upon_const_tensor_extra_check,
)
def scatter_upon_const_tensor(
    match: Match, shape, background_val, dtype, dim, selector, val
):
    """
    Match the full+scatter pattern and lower it into a pointwise op.

    TODO: Right now the scatter value must be a scalar. But we could support it
    when it is a tensor as well.
    """
    from torch._inductor import metrics

    metrics.num_matches_for_scatter_upon_const_tensor += 1

    selector_loader = selector.make_loader()

    def inner_fn(idx):
        selector_idx = list(idx)
        selector_idx[dim] = 0

        selector = selector_loader(selector_idx)
        return ops.where(
            selector == ops.index_expr(idx[dim], torch.int64),
            ops.constant(val, dtype),
            ops.constant(background_val, dtype),
        )

    return ir.Pointwise.create(
        device=selector.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=shape,
    )


@register_lowering_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
        CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")),
    ),
    extra_check=is_valid_mm_plus_mm,
)
def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
    return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)


def cuda_and_enabled_mixed_mm(match):
    return (
        (config.use_mixed_mm or config.mixed_mm_choice != "default")
        and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
        and (
            match.kwargs["mat2_dtype"].itemsize
            > match.kwargs["mat2"].meta.get("val").dtype.itemsize
        )
        and has_backend_feature("cuda", BackendFeature.TRITON_TEMPLATES)
    )


def cuda_and_enabled_mixed_mm_and_not_int8(match):
    return (
        cuda_and_enabled_mixed_mm(match)
        and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
        and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
        != torch.int8
    )  # bitshift numerics in triton and pytorch don't match for torch.int8


"""
    this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor
    (where the int4 and uint4x2 are represented with int8 and uint8 respectively)
    where every other row of the int4 is packed with the row above it as:
    uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4

    unpack formulas:
    int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8
    int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8

    thus matching on unpack formula:
    torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))

    note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior
    of the kernel matches the pytorch formula for all dtypes except torch.int8,
    where the bitwise numerics in triton do not match those in pytorch.
"""
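# Worked example of the packing scheme above (illustrative numbers only):
# packing the int4 values 3 (row 2*k) and -2 (row 2*k+1) gives
#     uint4x2 = (8 + 3) + ((8 + (-2)) << 4) = 11 + 96 = 107
# and the unpack formulas recover them:
#     (107 & 0xF) - 8 = 3        (107 >> 4) - 8 = -2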


@register_lowering_pattern(
    CallFunction(
        aten.mm.default,
        KeywordArg("mat1"),
        CallFunction(
            aten.sub.Tensor,
            CallFunction(
                prims.convert_element_type.default,
                CallFunction(
                    aten.reshape.default,
                    CallFunction(
                        aten.cat.default,
                        ListOf(
                            CallFunction(
                                aten.bitwise_and.Scalar,
                                KeywordArg("mat2"),
                                0xF,
                            ),
                            # CallFunction(
                            #     aten.__rshift__.Scalar,
                            #     KeywordArg("mat2"),
                            #     4,
                            # ),
                            True,
                        ),
                        1,
                    ),
                    KeywordArg("mat2_mm_shape"),
                ),
                KeywordArg("mat2_dtype"),
            ),
            8,
        ),
    ),
    extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
)
def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
    return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
        mat1, mat2, mat2_mm_shape, mat2_dtype
    )


"""
    torch.mm(mat1, mat2.to(mat2_dtype))
"""


@register_lowering_pattern(
    CallFunction(
        aten.mm,
        KeywordArg("mat1"),
        CallFunction(
            prims.convert_element_type.default,
            KeywordArg("mat2"),
            KeywordArg("mat2_dtype"),
        ),
    ),
    extra_check=cuda_and_enabled_mixed_mm,
)
def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
    return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)


@register_graph_pattern(
    CallFunction(
        aten.cumsum.default,
        CallFunction(
            torch.ops.aten.full.default,
            KeywordArg("shape"),
            KeywordArg("fill_value"),
            dtype=KeywordArg("dtype"),
            layout=Ignored(),
            device=KeywordArg("device"),
            pin_memory=False,
            _users=MULTIPLE,
        ),
        KeywordArg("dim"),
        _users=MULTIPLE,
    ),
    pass_dict=pass_patterns[1],
)
def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
    """Based on a pattern in OPTForCausalLM"""

    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
        # cumsum promotes all integral types to int64
        dtype = torch.int64

    def repl(*shape):
        dim_size = shape[dim]
        idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)

        inter_shape = [1] * len(shape)
        inter_shape[dim] = dim_size
        return (idx * fill_value).view(inter_shape).expand(shape)

    # only replace the output node, not all nodes
    match.nodes = [match.output_node()]
    match.replace_by_example(repl, list(shape))


def shape_of_mm(a, b):
    m, _ = a.get_size()
    _, n = b.get_size()
    return [m, n]


@register_lowering_pattern(
    CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
)
def cat_mm(match, inputs, dim):
    return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)


@register_lowering_pattern(
    CallFunction(
        aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg()
    ),
)
def cat_addmm(match, inputs, dim):
    def shape_of(bias, a, b):
        m, _ = a.get_size()
        _, n = b.get_size()
        return [m, n]

    return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of)
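

# Hedged sketch of what cat_tuned_op below replaces (hypothetical eager-mode
# equivalent, illustrative names): a matched fragment such as
#
#     torch.cat([torch.mm(a1, b1), torch.mm(a2, b2)], dim)
#
# is lowered to a single preallocated output buffer, with each matmul autotuned
# to write directly into its slice of that buffer, so no separate concatenation
# kernel or extra copy is emitted.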
def cat_tuned_op(match, inputs, dim, *, op, shape_of):
    """
    Memory planning to remove cat. We can't use the stock memory
    planner since autotuning matmuls need to know the output layout.
    """
    if len(inputs) == 1:
        return op(*inputs[0])

    # TODO(jansel): rewrite this as a bmm?
    if dim < 0:
        dim += len(shape_of(*inputs[0]))
    assert dim in (0, 1)
    notdim = 1 - dim

    new_size: Optional[Union[List[Expr], List[int]]] = None
    offsets_start = []
    offsets_end = []

    # compute output sizes
    for i in range(len(inputs)):
        shape = shape_of(*inputs[i])
        if new_size is None:
            new_size = shape
        else:
            new_size[notdim] = V.graph.sizevars.guard_equals(  # type: ignore[call-overload]
                shape[notdim], new_size[notdim]
            )
        new_size[dim] += shape[dim]
        offsets_start.append(new_size[dim] - shape[dim])
        offsets_end.append(new_size[dim])

    assert new_size is not None
    dtype = functools.reduce(
        torch.promote_types,
        [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
    )
    device = inputs[0][0].get_device()
    kernel = ir.ConcatKernel(
        name=None,
        layout=ir.FixedLayout(device, dtype, new_size),
        inputs=[],
    )
    kernel_tensor = ir.TensorBox.create(kernel)

    for i in range(len(inputs)):
        dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i])
        src = op(*inputs[i], layout=dst.get_layout()).data.data
        assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer))
        src.layout = ir.NonOwningLayout(dst)
        kernel.inputs.append(src)

    kernel.name = V.graph.register_buffer(kernel)
    kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs)
    V.graph.register_operation(kernel)
    return kernel_tensor


_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)


@register_lowering_pattern(
    CallFunction(
        aten.cat,
        [
            _cat_1,
            CallFunction(
                aten.slice,
                _cat_1,
                1,
                0,
                KeywordArg("size"),
            ),
        ],
        1,
    )
)
def cat_slice_cat(match, cat_input, size, dim=1):
    """
    This is an example of a more complex pattern where cat_1 is used
    multiple times inside the pattern. We fold 2 calls to cat into one.

    Matches:
        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)


    Rewrite to:
        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1)
    """
    first, *rest = cat_input
    # This optimization is optional, because we can simply not fold the cat.
    # `size` must be within first.get_size()[dim] for the optimization to be valid.
    # For a negative `end`, we currently fall back to not optimizing.
    if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]):
        # fold 2 cats into 1 cat
        return L[aten.cat](
            [
                first,
                *rest,
                L[aten.slice](first, dim, 0, size),
            ],
            dim,
        )
    else:
        # don't expect to hit this case, just fall back
        tmp = L[aten.cat](cat_input, dim)
        return L[aten.cat](
            [
                tmp,
                L[aten.slice](tmp, dim, 0, size),
            ],
            dim,
        )


def is_valid_splitwithsizes_cat(match):
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    cat_nodes = filter_nodes(match.nodes, aten.cat)
    get_item_nodes = filter_nodes(match.nodes, operator.getitem)
    if len(split_nodes) != 1 or len(cat_nodes) != 1:
        return False
    split_node, cat_node = split_nodes[0], cat_nodes[0]
    # The dim of split and cat should match for passthrough
    if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"):
        return False
    get_item_args = {
        get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes
    }
    assert None not in get_item_args
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # All parts of split should be included in the cat
    if get_item_args != set(range(len(split_sizes))):
        return False
    # The order of get_item_args should match the order in which cat_node uses them.
    # For example, if split_node is split_with_sizes(input, [2, 2, 3], 1),
    # the cat node should be cat([get_item(0), get_item(1), get_item(2)], 1).
    cat_items_args_order = [
        get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0)
    ]
    if cat_items_args_order != list(range(len(split_sizes))):
        return False

    return True


def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
    """True if two nodes have the same metadata"""
    val1 = node1.meta.get("val")
    val2 = node2.meta.get("val")
    return (
        val1 is not None
        and val2 is not None
        and statically_known_true(sym_eq(val1.size(), val2.size()))
        and val1.layout == val2.layout
        and val1.dtype == val2.dtype
        and val1.device == val2.device
        and (
            val1.layout != torch.strided
            or statically_known_true(sym_eq(val1.stride(), val2.stride()))
        )
    )


noop_registry: Dict[Any, Any] = {}


def register_noop_decomp(targets, nop_arg=0):
    def register_fun(cond):
        register_decomposition(targets, registry=noop_registry, unsafe=True)(
            (cond, nop_arg)  # type: ignore[arg-type]
        )
        return cond

    return register_fun


@register_noop_decomp(aten.slice)
def slice_noop(self, dim=0, start=None, end=None, step=1):
    if start is None or end is None:
        return False
    if (
        statically_known_true(sym_eq(start, 0))
        and statically_known_true(end >= 2**63 - 1)
        and statically_known_true(sym_eq(step, 1))
    ):
        return True
    return False


@register_noop_decomp(aten.slice_scatter, 1)
def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1):
    if start is None:
        start = 0
    if end is None:
        end = 2**63 - 1
    if start == 0 and end >= 2**63 - 1 and step == 1:
        return True
    return False


@register_noop_decomp(aten.repeat)
def repeat_noop(self, repeats):
    return all(r == 1 for r in repeats)


@register_noop_decomp(aten.constant_pad_nd)
def constant_pad_nd(x, padding, fill_value=0):
    return all(p == 0 for p in padding)


@register_noop_decomp(torch.ops.prims.convert_element_type)
def convert_element_type_noop(x, dtype: torch.dtype):
    return x.dtype == dtype


@register_noop_decomp(torch.ops.prims.device_put)
def device_put_noop(x, device):
    return x.device == decode_device(device)


@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc])
def int_noop(x):
    return is_integer_dtype(x.dtype)


@register_noop_decomp([aten.pow])
def pow_noop(a, b):
    return isinstance(b, int) and b == 1


@register_noop_decomp([aten.cat], lambda args: args[0][0])
def cat_noop(inputs, dim=0):
    return len(inputs) == 1


@register_noop_decomp(aten.view)
def view_noop(arg, size):
    return arg.shape == size


# Note, we also always have a check for identical metadata, which is why these
# are safe
@register_noop_decomp([aten.copy], nop_arg=1)
@register_noop_decomp([aten.alias, aten.clone])
def true_noop(*args, **kwargs):
    return True


def remove_noop_ops(graph: torch.fx.Graph):
    """
    Removes both operations that are essentially aten.clone and operations
    that are essentially aten.alias from the graph.
    """
    inputs = set()
    input_storages = set()
    output_storages = set()

    for node in graph.find_nodes(op="placeholder"):
        inputs.add(node)
        input_storages.add(get_node_storage(node))

    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output"
    outputs = output_node.args[0]
    if not isinstance(outputs, (list, tuple)):
        # nested subgraphs can have singleton outputs
        outputs = (outputs,)
    for out in outputs:
        if isinstance(out, torch.fx.Node):
            output_storages.add(get_node_storage(out))

    for node in graph.nodes:
        if node.target in noop_registry:
            cond, src_index = noop_registry[node.target]
            if isinstance(src_index, int):
                src = node.args[src_index]
            else:
                src = src_index(node.args)
            if not isinstance(src, torch.fx.Node):
                continue
            # Don't introduce new aliasing between inputs and outputs.
            # See fx_passes/README.md for a discussion of why this is
            # necessary.
            node_storage = get_node_storage(node)
            src_storage = get_node_storage(src)
            node_is_view = node_storage == src_storage
            if (
                not node_is_view
                and node_storage in output_storages
                and (src_storage in input_storages or src_storage in output_storages)
            ):
                continue

            # Even if inputs and outputs are expected to alias,
            # don't make "node is src" True
            if (
                node_is_view
                and node in output_node.args
                and (src in inputs or src in output_node.args)
            ):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            if same_meta(node, src) and cond(*args, **kwargs):
                node.replace_all_uses_with(src)
                graph.erase_node(node)
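

# Illustrative example of what remove_noop_ops above eliminates (hypothetical
# graph fragment, not generated by this file):
#
#     y = torch.ops.aten.clone.default(x)
#     z = torch.ops.aten.slice.Tensor(y, 0, 0, 9223372036854775807)
#
# Both nodes satisfy their registered noop conditions and same_meta(), so uses
# of `z` and `y` are rewritten to `x` and the nodes are erased, unless doing so
# would introduce new aliasing between graph inputs and outputs (in that case
# they are kept).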


def decompose_auto_functionalized(graph):
    """Decomposes auto_functionalized and triton_kernel_wrapper_functional
    nodes into clones and the underlying mutation node.

    We assume that the reinplacing pass runs before this; the reinplacing pass
    tells us (via rewriting the arguments or .meta to those nodes) which
    Tensors we should clone and which Tensors are safe to reinplace.
    """
    graph_pass = PatternMatcherPass()

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense

        only_clone_these_tensors = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx, which does not
        # support tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_functional_dense,
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx, which does not
        # support tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return (triton_kernel_wrapper_functional_dense(*args, **kwargs),)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import (
            auto_functionalized_v2_dense,
        )

        only_clone_these_bases = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx, which does not
        # support tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    graph_pass.apply(graph)

    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized
    ):
        raise AssertionError("auto_functionalized was not removed")

    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized_v2
    ):
        raise AssertionError("auto_functionalized_v2 was not removed")

    for node in graph.find_nodes(
        op="call_function",
        target=torch.ops.higher_order.triton_kernel_wrapper_functional,
    ):
        raise AssertionError("triton_kernel_wrapper_functional was not removed")


@register_lowering_pattern(
    CallFunction(
        aten.cat,
        ListOf(
            CallFunction(
                operator.getitem,
                CallFunction(
                    aten.split_with_sizes,
                    KeywordArg("input_"),
                    Ignored(),
                    Ignored(),
                    _users=MULTIPLE,
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_splitwithsizes_cat,
)
def splitwithsizes_cat_replace(match, input_):
    return input_


def is_valid_cat_splitwithsizes(match):
    cat_nodes = filter_nodes(match.nodes, aten.cat)
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    if len(split_nodes) != 1 or len(cat_nodes) != 1:
        return False
    split_node, cat_node = split_nodes[0], cat_nodes[0]

    # the cat node has other users: can't eliminate
    if len(cat_node.users) > 1:
        return False

    # the dim of the cat and split should match
    dim = get_arg_value(split_node, 2, "dim")
    if dim != get_arg_value(cat_node, 1, "dim"):
        return False

    cat_inputs = list(get_arg_value(cat_node, 0))
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # the number of input tensors in cat and the
    # length of the split sizes should match
    if len(cat_inputs) != len(split_sizes):
        return False

    for cat_input, split_size in zip(cat_inputs, split_sizes):
        # each cat input tensor's size along dim
        # should match the corresponding split size
        if "val" not in cat_input.meta:
            return False
        cat_input_size = cat_input.meta["val"].size(dim)
        if cat_input_size != split_size:
            return False

    return True


@register_lowering_pattern(
    CallFunction(
        aten.split_with_sizes,
        CallFunction(
            aten.cat,
            KeywordArg("input_"),
            Ignored(),
            _users=MULTIPLE,
        ),
        Ignored(),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_cat_splitwithsizes,
)
def cat_splitwithsizes_replace(match, input_):
    return input_


def view_to_reshape(gm):
    """
    Replace view ops in the GraphModule with reshape ops.
    """
    for nd in gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.view.default
    ):
        nd.target = torch.ops.aten.reshape.default


def should_prefer_unfused_addmm(match):
    inp = match.kwargs["inp"]
    if not inp.meta["val"].is_cuda:
        return False

    output = match.output_node()
    return all(is_pointwise_use(use) for use in output.users)


@register_graph_pattern(
    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
    pass_dict=pass_patterns[2],
    extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
    def repl(inp, x1, x2):
        return x1 @ x2 + inp

    match.replace_by_example(repl, [inp, mat1, mat2])


def is_valid_addmm_fusion(match):
    mat1, mat2 = match.args
    inp = match.kwargs["inp"]

    if not (
        isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
    ):
        return False  # Input is a number

    in_shape = inp.meta["val"].shape
    mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1]
    matched = is_expandable_to(in_shape, mm_shape)
    if not matched:
        return False  # Shape mismatch

    return not should_prefer_unfused_addmm(match)


@register_graph_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, Arg(), Arg()),
        KeywordArg("inp"),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
@register_graph_pattern(
    CallFunction(
        aten.add,
        KeywordArg("inp"),
        CallFunction(aten.mm, Arg(), Arg()),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
def addmm(match, mat1, mat2, *, inp):
    def repl(inp, mat1, mat2):
        return aten.addmm(inp, mat1, mat2)

    match.replace_by_example(repl, [inp, mat1, mat2])


def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
    return (
        config.force_fuse_int_mm_with_mul
        and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2
        and getattr(match.args[2].meta.get("val"), "is_cuda", False)
    )


@register_lowering_pattern(
    CallFunction(
        prims.convert_element_type.default,
        CallFunction(
            aten.mul,
            CallFunction(
                aten._int_mm,
                Arg(),
                Arg(),
            ),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
@register_lowering_pattern(
    CallFunction(
        aten.mul,
        CallFunction(
            aten._int_mm,
            Arg(),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None):
    return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
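

# Illustrative scenario for the check below (hedged sketch, hypothetical names;
# not code emitted by this pass): with a boolean index and a cpu scalar value,
#
#     gpu_tensor.index_put_((bool_mask_on_gpu,), torch.tensor(2.0))
#
# the aten fallback takes the masked_fill_ fast path and avoids the sync,
# whereas converting that cpu scalar into a gpu tensor would force index_put_
# to call nonzero() on the mask and synchronize. The check below lets the
# constructor-mover pass leave such cpu scalars where they are.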
def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
    from torch.fx.operator_schemas import normalize_function

    if node.target not in [
        torch.ops.aten.index_put.default,
        torch.ops.aten.index_put_.default,
    ]:
        return False
    # Inductor falls back to aten.index_put_.
    # index_put_ will call nonzero() and perform an H2D sync if
    # any of its indices are bool/byte tensors.
    # However, it will short-circuit this H2D sync and run masked_fill_
    # if the value we are putting is a cpu scalar.
    # Therefore, when inductor sees an index_put_ with byte tensor indices,
    # it should *not* convert the cpu scalar value into a gpu tensor.
    args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)  # type: ignore[misc]
    any_byte_bool_indices = False
    indices = args_[1]
    for i in indices:
        if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]:
            any_byte_bool_indices = True

    val = args_[2].meta["val"]
    val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1
    # If both these conditions hold, then converting the val
    # to a gpu tensor will incur an H2D sync when inductor calls aten.index_put_
    return any_byte_bool_indices and val_is_cpu_scalar


class ConstructorMoverPass:
    def __init__(self, target: str, allow_outputs: bool = False) -> None:
        """
        Move constructors from cpu to the target_device.

        Sweeps through the module, looking for constructor nodes that can be moved
        to the target_device.

        A constructor node can be moved to the target_device iff all of its users
        can also be moved (tested by cannot_be_moved). Otherwise, all dependent
        constructor nodes won't be moved.

        - target: target device type
        - allow_outputs: allow outputs to be moved
        """

        self.target = target
        self.allow_outputs = allow_outputs

        assert isinstance(target, str), (
            "target should be a string representing the device type. "
            f"Got: {type(target).__name__}"
        )

    def allow_cpu_device(self, node: fx.Node) -> bool:
        """
        Returns whether a node that returns a tensor on the target device may have
        cpu tensors as input.
        """
        return node.target in (
            torch.ops.aten.index.Tensor,
            torch.ops.aten.index_put.default,
            torch.ops.aten.index_put_.default,
            torch.ops.aten.copy.default,
            torch.ops.aten.copy_.default,
            torch.ops.aten.slice_scatter.default,
        )

    def cannot_be_moved(self, node: fx.Node) -> bool:
        """
        Returns whether a node cannot be moved to the target device.

        If this function returns True, it means that this node and all of its users
        won't be moved into the target device.
        """
        if node.target == "output":
            return not self.allow_outputs

        if not (
            isinstance(node.target, torch._ops.OpOverload)
            and node.target.namespace in ("prims", "aten")
        ):
            return True
        if is_index_put_and_requires_h2d_sync_for_gpu_value(node):
            return True

        return False

    def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
        """
        Get the device of a node.
        """
        ten = node.meta.get("val")
        return None if not isinstance(ten, torch.Tensor) else ten.device

    def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
        """
        Get the number of cpu inputs to a node
        """
        cpu_indeg: Dict[fx.Node, int] = Counter()

        for node in graph.nodes:
            cpu_count = 0

            def add_cpu_inp(node):
                nonlocal cpu_count
                device = self.get_node_device(node)
                cpu_count += device is not None and device.type == "cpu"

            pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))

            if cpu_count:
                cpu_indeg[node] = cpu_count

        return cpu_indeg

    def __call__(self, graph: fx.Graph) -> None:
        target_devices = set()
        constructors = []

        for node in graph.nodes:
            device = self.get_node_device(node)
            if device and device.type == self.target:
                target_devices.add(device)

            if not (
                isinstance(node.target, torch._ops.OpOverload)
                and node.target.namespace in ("prims", "aten")
            ):
                continue

            if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
                continue

            if not node.kwargs.get("device") == torch.device("cpu"):
                continue

            constructors.append(node)

        # not handling multiple target devices initially
        if not constructors or len(target_devices) != 1:
            return

        movable_constructors = self.find_movable_constructors(graph, constructors)

        for node in movable_constructors:
            kwargs = node.kwargs.copy()
            kwargs["device"] = next(iter(target_devices))
            node.kwargs = kwargs

    def find_movable_constructors(
        self, graph: fx.Graph, constructors: List[fx.Node]
    ) -> Set[fx.Node]:
        """
        Starting from the cpu constructors, iterate through the graph and test that all of their
        downstream uses can safely be moved to the target device.
        """
        cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)

        # which constructors cannot be moved to gpu
        cannot_move_to_gpu: Set[fx.Node] = set()

        # For any node in the graph, which constructors does it have a dependency on
        constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)

        # if a cpu node has a dependency on two different cpu constructors,
        # then if either constructor cannot be moved to gpu, the other cannot as well.
        # In this case any node with a dependency on one will have a dependency on the other
        equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
            c: {c} for c in constructors
        }

        def make_dependencies_equivalent(
            set1: Set[fx.Node], set2: Set[fx.Node]
        ) -> Set[fx.Node]:
            # could use union find but not worth complexity here
            set1.update(set2)
            for obj in set1:
                equal_constructor_sets[obj] = set1
            return set1

        queue: List[fx.Node] = list(constructors)

        for c in queue:
            constructor_dependencies[c].add(c)

        while queue:
            node = queue.pop()
            dependencies = constructor_dependencies[node]

            for user in node.users:
                if self.cannot_be_moved(user):
                    cannot_move_to_gpu.update(dependencies)
                    break

                # this node was used by an op which takes in multiple devices and outputs a gpu
                # tensor. We can convert its cpu input to gpu without making further changes
                node_device = self.get_node_device(user)
                if (
                    self.allow_cpu_device(user)
                    and node_device
                    and node_device.type == self.target
                ):
                    del cpu_indeg[user]
                else:
                    # otherwise, we should continue looking at its downstream uses
                    cpu_indeg[user] -= 1
                    if cpu_indeg[user] == 0:
                        del cpu_indeg[user]
                        queue.append(user)

                unioned_set = make_dependencies_equivalent(
                    dependencies, constructor_dependencies[user]
                )
                constructor_dependencies[user] = unioned_set

        for node in cpu_indeg:
            if constructor_dependencies[node]:
                cannot_move_to_gpu.update(constructor_dependencies[node])

        all_cannot_move_to_gpu = cannot_move_to_gpu.copy()
        for constructor in cannot_move_to_gpu:
            all_cannot_move_to_gpu.update(equal_constructor_sets[constructor])

        return set(constructors) - all_cannot_move_to_gpu


def move_constructors_to_gpu(graph: fx.Graph) -> None:
    """
    Moves intermediary tensors which are constructed on the cpu to the gpu when it is safe to do so.
    """
    ConstructorMoverPass(get_gpu_type())(graph)
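

# Minimal usage sketch (illustrative only; in practice this pass is invoked
# from post_grad_passes above rather than called directly). Assuming `gm` is a
# traced torch.fx.GraphModule whose graph already contains gpu tensors:
#
#     move_constructors_to_gpu(gm.graph)
#     gm.recompile()
#
# Any cpu constructor node (e.g. torch.zeros(..., device="cpu")) whose
# downstream users all tolerate the move is rewritten in place to construct
# directly on the graph's single gpu device.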