# mypy: allow-untyped-decorators
"""
# Inductor Pattern Matcher

The pattern matcher enables search/replace within an FX graph.

The main entrypoint to the pattern matcher is register_replacement(). Given a
search function and a replacement function, this will register a replacement
with a pass (such as torch._inductor.fx_passes.joint_graph.patterns).

Internally the pattern matcher represents patterns as a graph (a DAG). Creating
new patterns manually as a graph is cumbersome and error-prone, so the standard
way to create patterns (using register_replacement()) is to provide a search
function and a replacement function which are traced and converted into a graph.

Because the search functions are written somewhat generically (they tend to
ignore tensor sizes, for example), register_replacement() lets you provide an
`extra_check` function which performs additional checks to verify that the
matched pattern fully matches before the replacement is applied.

## Precompiled Patterns

New patterns are added using register_replacement(). Patterns added in this way
incur a compile-time overhead because the search function must be traced at
registration time. To avoid that cost, patterns can be precompiled and
registered with gen_register_replacement() instead. The arguments are the same
except for an additional unique name which is used as a lookup key.

## Internals

The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
implements a `_match` method which returns either a `Match` object for a
successful match or a `FailedMatch` object for a failure to match.
"""

from __future__ import annotations

import contextlib
import dataclasses
import functools
import importlib
import inspect
import itertools
import logging
import operator
import os
import re
import textwrap
import typing
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Generator,
    Iterable,
    List,
    Mapping,
    NoReturn,
    Optional,
    Protocol,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from typing_extensions import Self, TypeGuard

import torch
import torch._guards
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import counters
from torch._inductor.config import trace as trace_config
from torch._prims_common import is_integer_dtype
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.graph_transform_observer import GraphTransformObserver

from .._functorch import config as functorch_config
from .._functorch.aot_autograd import aot_function, make_boxed_func
from .._functorch.partitioners import default_partition
from .._subclasses import FakeTensor, FakeTensorMode
from ..fx import Transformer
from . import config
from .decomposition import select_decomp_table
from .lowering import fallback_node_due_to_unsupported_type


log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

Constant = Any
NodeOrConstant = Union[Constant, torch.fx.Node]


class SearchFn(Protocol):
    __name__: str

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...


class ReplaceFn(Protocol):
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...


class TraceFn(Protocol):
    def __call__(
        self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
    ) -> torch.fx.GraphModule:
        ...


T = TypeVar("T")

# What's a better name for this?
FnsType = Union[torch.fx.node.Target, str]


class Multiple:
    def __init__(self) -> None:
        # Ensure we're really a singleton.
        assert "MULTIPLE" not in globals() or self is MULTIPLE


# Sentinel indicating multiple quantities can be matched
MULTIPLE = Multiple()


class Match:
    """
    Represents a successfully matched pattern.

    The `Match` object is returned to represent a successfully matched
    pattern. Included in the Match are the pattern that was matched, the graph
    nodes matched, and any args that were used during the matching.

    The args and kwargs are specific to the type of pattern that was matched and
    provide hints about what was matched.
    """

    pattern: PatternExpr
    args: List[Any]
    kwargs: Dict[str, Any]
    nodes: List[torch.fx.Node]
    targets: Dict[_TargetExpr, torch.fx.node.Target]
    ctx: MatchContext
    replacement_graph: Optional[torch.fx.Graph]

    def __init__(
        self,
        ctx: MatchContext,
        pattern: PatternExpr,
        args: Optional[Sequence[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.pattern = pattern
        # The input nodes that must be passed in to the result
        self.args = list(args or [])
        self.kwargs = kwargs or {}
        # The nodes matched in this expression
        self.nodes = []
        # Mapping CallFunction to the node.target
        self.targets = {}
        self.ctx = ctx
        self.replacement_graph = None

    @property
    def graph(self) -> torch.fx.Graph:
        return self.ctx.graph

    def extend(self, other: Match) -> None:
        if self.kwargs:
            for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
                if self.kwargs[key] != other.kwargs[key]:
                    raise FailedMatch("kwarg mismatch: {}", key)
        self.args.extend(other.args)
        self.nodes.extend(other.nodes)
        self.kwargs.update(other.kwargs)
        self.targets.update(other.targets)

    def bundle(self) -> Match:
        # Wrap args in an extra list
        self.args = [tuple(self.args)] if self.args else []
        return self

    def __repr__(self) -> str:
        return f"Match(..., {self.args}, {self.kwargs})"

    def erase_nodes(self) -> None:
        graph = self.graph
        for n in reversed(self.nodes):
            if not n._erased and not n.users:
                graph.erase_node(n)

    def output_nodes(self) -> List[Optional[torch.fx.Node]]:
        return [
            (self.ctx.pattern_to_node[p] if p is not None else None)
            for p in self.ctx.outputs
        ]

    def output_node(self) -> torch.fx.Node:
        return next(p for p in self.output_nodes() if p)

    def replace_with_graph(
        self, replacement_graph: torch.fx.Graph, args: Sequence[Any]
    ) -> None:
        ReplacementPatternEntry.replace_with_graph(
            self, self.ctx.graph, replacement_graph, args
        )

    def replace_by_example(
        self,
        replacement_fn: ReplaceFn,
        args: Sequence[Any],
        trace_fn: Optional[TraceFn] = None,
        run_functional_passes: bool = True,
    ) -> None:
        """Replace with a graph generated by tracing the replacement_fn.

        Args:
            run_functional_passes (bool): Whether to run passes that assume
                functional IR (like DCE and remove_noop_ops) on the
                replacement graph.

        """
        from torch._inductor.virtualized import NullHandler, V

        context = (
            V.fake_mode
            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
            else contextlib.nullcontext()
        )

        with context:
            if trace_fn is None:
                trace_fn = functools.partial(
                    fwd_only, run_functional_passes=run_functional_passes
                )
            replacement = trace_fn(
                replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
            )
            ReplacementPatternEntry.replace_with_graph(
                self,
                self.ctx.graph,
                replacement,
                args,
            )


class FailedMatch(RuntimeError):
    """
    Represents an unsuccessful match.

    The `FailedMatch` object is returned to represent a failure to match a
    pattern.
    """

    format_string: str

    def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
        self.format_string = format_string
        # We want to construct error messages lazily instead of eagerly, as
        # constructing them eagerly can significantly worsen compile times.
        if len(format_string) > 200:
            raise RuntimeError(
                f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
            )
        self.args = args
        self.kwargs = kwargs

    def __str__(self) -> str:
        return self.format_string.format(*self.args, **self.kwargs)

    def __bool__(self) -> bool:
        return False


MatchResult = Union[Match, FailedMatch]


def is_match(m: MatchResult) -> TypeGuard[Match]:
    """
    TypeGuards cannot act on `self`. Thus this function exists to let mypy
    recognize FailedMatch.__bool__ as a TypeGuard.
    """
    return bool(m)


class MatchContext:
    """
    Internal state needed while running PatternExpr._match().
306 """ 307 308 outputs: List[Optional[PatternExpr]] 309 pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]] 310 graph: torch.fx.Graph 311 exclusive_node_set: List[NodeOrConstant] 312 313 def __init__( 314 self, 315 outputs: List[Optional[PatternExpr]], 316 pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None, 317 *, 318 graph: torch.fx.Graph, 319 ) -> None: 320 self.outputs = outputs 321 self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node) 322 self.graph = graph 323 self.exclusive_node_set = [] 324 325 def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult: 326 """wrapper to check reused nodes in patterns""" 327 if pattern in self.pattern_to_node: 328 if self.pattern_to_node[pattern] == node: 329 return Match(self, pattern) # already checked this node 330 else: 331 return FailedMatch("repeated pattern differs") 332 m = pattern._match(node, self) 333 assert pattern not in self.pattern_to_node 334 self.pattern_to_node[pattern] = node if m else None 335 return m 336 337 def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]: 338 return { 339 pattern: node 340 for pattern, node in self.pattern_to_node.items() 341 if pattern.has_multiple_users() and node is not None 342 } 343 344 345class PatternExpr(ABC): 346 """ 347 Base class for types of patterns. 348 """ 349 350 @abstractmethod 351 def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: 352 ... 353 354 def match(self, node: torch.fx.Node) -> MatchResult: 355 try: 356 return MatchContext([self], graph=node.graph).match(self, node) 357 except FailedMatch as e: 358 return e 359 360 def has_multiple_users(self) -> bool: 361 return False 362 363 def __repr__(self) -> str: 364 return self.__class__.__name__ + "()" 365 366 def find_anchor_nodes( 367 self, ctx: MatchContext, searched: Set[torch.fx.Node] 368 ) -> Generator[Optional[torch.fx.Node], None, None]: 369 if self in ctx.pattern_to_node: 370 yield ctx.pattern_to_node[self] 371 372 def pattern_eq(self, other: Any) -> bool: 373 """ 374 Compare two `PatternExpr`s and return true if they are the 375 same. Note this is NOT matching a pattern - it is comparing the pattern 376 structures (for debugging). 377 """ 378 return isinstance(other, self.__class__) 379 380 381class Arg(PatternExpr): 382 """ 383 Capture an arg which will become an input to the handler. Args are 384 passed in depth first order. 385 """ 386 387 def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: 388 return Match(ctx, self, args=[node]) # matches anything 389 390 391class Ignored(PatternExpr): 392 """ 393 Match an arg, but don't pass it to handler 394 """ 395 396 def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: 397 return Match(ctx, self) # matches anything 398 399 def __repr__(self) -> str: 400 return "*" 401 402 def pretty_print(self, pp: PatternPrettyPrinter) -> str: 403 return "Ignored()" 404 405 406class KeywordArg(PatternExpr): 407 """ 408 Capture a kwarg which will become an input to the handler. 
409 """ 410 411 def __init__(self, name: str) -> None: 412 super().__init__() 413 self.name = name 414 415 def __repr__(self) -> str: 416 return f"KeywordArg({self.name!r})" 417 418 def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: 419 return Match(ctx, self, kwargs={self.name: node}) # matches anything 420 421 def pattern_eq(self, other: Any) -> bool: 422 other = typing.cast(Self, other) # super makes sure this is true 423 return super().pattern_eq(other) and self.name == other.name 424 425 426class ExclusiveKeywordArg(PatternExpr): 427 """ 428 Capture a kwarg which will become an input to the handler. 429 """ 430 431 name: str 432 433 def __init__(self, name: str) -> None: 434 super().__init__() 435 self.name = name 436 437 def __repr__(self) -> str: 438 return f"ExclusiveKeywordArg({self.name!r})" 439 440 def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: 441 if node in ctx.exclusive_node_set: 442 return FailedMatch("exclusive arg appears twice") 443 444 ctx.exclusive_node_set.append(node) 445 return Match(ctx, self, kwargs={self.name: node}) # matches anything 446 447 def pattern_eq(self, other: Any) -> bool: 448 other = typing.cast(Self, other) # super makes sure this is true 449 return super().pattern_eq(other) and self.name == other.name 450 451 452class _TargetExpr(PatternExpr): 453 """ 454 Base class for filtering match by node.target 455 """ 456 457 fns: List[FnsType] 458 fns_set: Set[FnsType] 459 460 def __init__( 461 self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1 462 ) -> None: 463 super().__init__() 464 fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) 465 for fn in fns: 466 if isinstance(fn, torch._ops.OpOverloadPacket): 467 fns.extend(getattr(fn, overload) for overload in fn.overloads()) 468 469 self.fns = fns 470 self.fns_set = set(fns) 471 self.users = users 472 473 @property 474 @abstractmethod 475 def op(self) -> str: 476 ... 
477 478 def fns_repr(self) -> str: 479 first_repr = self.fns[0] 480 if not isinstance(first_repr, str): 481 first_repr = first_repr.__name__ 482 483 if len(self.fns) > 1: 484 return f"[{first_repr}, ...]" 485 elif self.fns[0] is getattr(torch, first_repr, None): 486 return f"torch.{first_repr}" 487 elif isinstance(self.fns[0], torch._ops.OpOverload): 488 return str(self.fns[0]) 489 else: 490 return first_repr 491 492 def __repr__(self) -> str: 493 if self.users is MULTIPLE: 494 comma_users = ", MULTIPLE" 495 elif self.users != 1: 496 comma_users = f", {self.users})" 497 else: 498 comma_users = "" 499 return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})" 500 501 def has_multiple_users(self) -> bool: 502 return isinstance(self.users, Multiple) or self.users > 1 503 504 def find_anchor_nodes( 505 self, ctx: MatchContext, searched: Set[torch.fx.Node] 506 ) -> Generator[Optional[torch.fx.Node], None, None]: 507 raise NotImplementedError 508 509 def _match_fns(self, node: torch.fx.Node) -> bool: 510 return ( 511 isinstance(node, torch.fx.Node) 512 and node.op == self.op 513 and extract_target(node) in self.fns_set 514 ) 515 516 def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool: 517 return ( 518 self in ctx.outputs 519 or self.users is MULTIPLE 520 or len(node.users) == self.users 521 ) 522 523 def pattern_eq(self, other: Any) -> bool: 524 other = typing.cast(Self, other) # super makes sure this is true 525 return ( 526 super().pattern_eq(other) 527 and self.op == other.op 528 and self.fns == other.fns 529 and self.users == other.users 530 ) 531 532 533_SimpleSpec = Tuple[Any, ...] 534 535 536class _TargetArgsExpr(_TargetExpr): 537 """ 538 Base class for filtering match by node.{target,args,kwargs} 539 """ 540 541 def __init__( 542 self, 543 fns: Union[torch.fx.node.Target, str, Sequence[Any]], 544 *args: Any, 545 _users: Union[int, Multiple] = 1, 546 **kwargs: Any, 547 ) -> None: 548 super().__init__(fns, _users) 549 self.args = tuple(args) 550 self.kwargs = dict(kwargs) 551 if any( 552 isinstance(x, (dict, list, tuple)) 553 for x in itertools.chain(args, kwargs.values()) 554 ): 555 self.flatten = self.pytree_flatten 556 else: 557 self.flatten = self.simple_flatten 558 self.flat_args_kwargs = self.flatten(self.args, self.kwargs) 559 560 @staticmethod 561 def simple_flatten( 562 args: Sequence[Any], kwargs: Mapping[Any, Any] 563 ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: 564 values = (*args, *kwargs.values()) 565 spec = (len(args), *kwargs.keys()) 566 return values, spec 567 568 @staticmethod 569 def pytree_flatten( 570 args: Sequence[Any], kwargs: Mapping[Any, Any] 571 ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: 572 def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec: 573 if s.type is None: 574 return s 575 mapping = {immutable_list: list, tuple: list, immutable_dict: dict} 576 return pytree.TreeSpec( 577 mapping.get(s.type, s.type), 578 s.context, 579 list(map(norm_spec, s.children_specs)), 580 ) 581 582 flat, spec = pytree.tree_flatten([args, kwargs]) 583 spec = norm_spec(spec) 584 return flat, spec 585 586 def __repr__(self) -> str: 587 args = [ 588 self.fns_repr(), 589 *map(repr, self.args), 590 *[f"{k}={v}" for k, v in self.kwargs.items()], 591 ] 592 if self.users is MULTIPLE: 593 args.append("_users=MULTIPLE") 594 elif self.users != 1: 595 args.append(f"_users={self.users}") 596 return f"{self.__class__.__name__}({', '.join(args)})" 597 598 def pretty_print(self, pp: PatternPrettyPrinter) -> str: 599 args = [ 600 
self.fns_repr(), 601 *(pp.pretty_print(x) for x in self.args), 602 *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()], 603 ] 604 if self.users is MULTIPLE: 605 args.append("_users=MULTIPLE") 606 elif self.users != 1: 607 args.append(f"_users={self.users}") 608 609 joiner_str = ", " 610 return f"{self.__class__.__name__}({joiner_str.join(args)})" 611 612 def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: 613 if not self._match_fns(node) or len(node.args) != len(self.args): 614 return FailedMatch("function_mismatch: node={}, pattern={}", node, self) 615 616 if not self._match_users(node, ctx): 617 return FailedMatch("multiple_users {}", self) 618 619 _args = node.args 620 _kwargs = node.kwargs 621 if len(_kwargs) < len(self.kwargs): 622 from torch.fx.operator_schemas import normalize_function 623 624 normalized_args_and_kwargs = normalize_function( 625 node.target, node.args, node.kwargs # type: ignore[arg-type] 626 ) 627 628 if normalized_args_and_kwargs is None: 629 return FailedMatch("function_mismatch: node={}, pattern={}", node, self) 630 else: 631 _args, _kwargs = normalized_args_and_kwargs 632 if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs): 633 _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} 634 else: 635 return FailedMatch( 636 "function_mismatch: node={}, pattern={}", node, self 637 ) 638 else: 639 _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} 640 641 node_items, node_spec = self.flatten(_args, _kwargs) 642 self_items, self_spec = self.flat_args_kwargs 643 if node_spec != self_spec: 644 return FailedMatch("args_structure {} {}", node_spec, self_spec) 645 assert len(node_items) == len(self_items) 646 647 m = Match(ctx, self) 648 for i, pattern, child_node in zip(itertools.count(), self_items, node_items): 649 if isinstance(pattern, PatternExpr): 650 child_match = ctx.match(pattern, child_node) 651 if not is_match(child_match): 652 return child_match 653 m.extend(child_match) 654 elif isinstance(child_node, torch.fx.Node) or child_node != pattern: 655 return FailedMatch( 656 "constant_args: {} {!r}!={pattern!r}", node, child_node 657 ) 658 m.nodes.append(node) 659 m.targets[self] = node.target 660 return m 661 662 def find_anchor_nodes( 663 self, ctx: MatchContext, searched: Set[torch.fx.Node] 664 ) -> Generator[Optional[torch.fx.Node], None, None]: 665 """ 666 This is used when we are matching a pattern with multiple outputs. 667 There is a partial match (stored in ctx) and we want to walk 668 this pattern to find a connection to an already-matched node. 669 670 Yields candidate nodes that `self._match` might like. 
671 """ 672 if self in ctx.pattern_to_node: 673 yield ctx.pattern_to_node[self] 674 return 675 676 for pattern in self.flat_args_kwargs[0]: 677 if isinstance(pattern, PatternExpr): 678 for other_node in pattern.find_anchor_nodes(ctx, searched): 679 if not isinstance(other_node, torch.fx.Node): 680 continue 681 for node in other_node.users: 682 if node not in searched: 683 if self._match_fns(node): 684 yield node 685 searched.add(node) 686 687 def pattern_eq(self, other: Any) -> bool: 688 other = typing.cast(Self, other) # super makes sure this is true 689 return ( 690 super().pattern_eq(other) 691 and self.flat_args_kwargs[1] == other.flat_args_kwargs[1] 692 and all( 693 a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b 694 for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0]) 695 ) 696 ) 697 698 699class CallFunction(_TargetArgsExpr): 700 """ 701 Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)` 702 """ 703 704 op = "call_function" 705 706 707class CallMethod(_TargetArgsExpr): 708 """ 709 Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)` 710 """ 711 712 op = "call_method" 713 714 715class CallModule(_TargetArgsExpr): 716 """ 717 Matches a call_module node in the FX graphs: `module(*args, **kwargs)` 718 """ 719 720 op = "call_module" 721 722 723class _TargetExprVarArgs(_TargetExpr): 724 """ 725 Matches a call_function node with any arguments which are passed into the pattern 726 """ 727 728 def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: 729 if not self._match_fns(node): 730 return FailedMatch("function_mismatch") 731 732 if not self._match_users(node, ctx): 733 return FailedMatch("multiple_users") 734 735 m = Match(ctx, self) 736 m.nodes.append(node) 737 m.targets[self] = node.target 738 m.args.extend(node.args) 739 m.kwargs.update(node.kwargs) 740 return m 741 742 743class CallFunctionVarArgs(_TargetExprVarArgs): 744 op = "call_function" 745 746 747class CallMethodVarArgs(_TargetExprVarArgs): 748 op = "call_method" 749 750 751class CallModuleVarArgs(_TargetExprVarArgs): 752 op = "call_module" 753 754 755class ListOf(PatternExpr): 756 """ 757 Matches a repeated pattern 758 """ 759 760 def __init__(self, pattern: PatternExpr, partial: bool = False) -> None: 761 super().__init__() 762 assert isinstance(pattern, PatternExpr) 763 self.pattern = pattern 764 self.partial = partial 765 766 def __repr__(self) -> str: 767 return f"{self.__class__.__name__}({self.pattern})" 768 769 def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] 770 if not isinstance(node, (list, tuple)) or len(node) == 0: 771 return FailedMatch("non_list") 772 m = Match(ctx, self) 773 # Propagating patterns with multiple users will ensure we don't revisit 774 # the same nodes 775 pattern_to_node = ctx.filter_multi_user_patterns() 776 matched = False 777 for i, child_node in enumerate(node): 778 child_ctx = MatchContext( 779 ctx.outputs, pattern_to_node, graph=child_node.graph 780 ) 781 child_match = child_ctx.match(self.pattern, child_node) 782 pattern_to_node = child_ctx.filter_multi_user_patterns() 783 if not is_match(child_match): 784 if not self.partial: 785 return FailedMatch("list[{}]: {}", i, child_match) 786 continue 787 matched = True 788 m.extend(child_match.bundle()) 789 if not matched: 790 return FailedMatch("list: no_match") 791 return m.bundle() 792 793 def pattern_eq(self, other: Any) -> bool: 794 other = typing.cast(Self, other) # super makes sure this is true 
795 return ( 796 super().pattern_eq(other) 797 and self.pattern.pattern_eq(other.pattern) 798 and self.partial == other.partial 799 ) 800 801 802class MultiOutputPattern(PatternExpr): 803 outputs: List[Optional[PatternExpr]] 804 805 def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: 806 super().__init__() 807 assert isinstance(outputs[0], _TargetExpr) 808 assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs 809 self.outputs = list(outputs) 810 self.op = outputs[0].op 811 812 @property 813 def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]: 814 # This cast is checked above in __init__() 815 output = typing.cast(_TargetExpr, self.outputs[0]) 816 return output.fns 817 818 def __repr__(self) -> str: 819 return f"{self.__class__.__name__}({self.outputs})" 820 821 def pretty_print(self, pp: PatternPrettyPrinter) -> str: 822 args = [pp.pretty_print(x) for x in self.outputs] 823 joiner_str = f",\n{' '}" 824 str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" 825 str_out = f"{str_out}\n])" 826 return str_out 827 828 def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: 829 output = typing.cast(_TargetExpr, self.outputs[0]) 830 m = ctx.match(output, node) 831 if not is_match(m): 832 return m 833 834 for pattern in self.outputs[1:]: 835 if pattern is None: 836 continue 837 child_match = self._match_from_anchors(pattern, ctx) 838 if not is_match(child_match): 839 return child_match 840 m.extend(child_match) 841 842 return m 843 844 def _match_from_anchors( 845 self, pattern: PatternExpr, ctx: MatchContext 846 ) -> MatchResult: 847 prior = dict(ctx.pattern_to_node) 848 m: MatchResult = FailedMatch("no anchor found") 849 for node in pattern.find_anchor_nodes(ctx, set()): 850 m = ctx.match(pattern, node) 851 if is_match(m): 852 return m 853 # revert any partial matches 854 ctx.pattern_to_node = dict(prior) 855 return m 856 857 def match(self, node: torch.fx.Node) -> MatchResult: 858 try: 859 return MatchContext(self.outputs, graph=node.graph).match(self, node) 860 except FailedMatch as e: 861 return e 862 863 def pattern_eq(self, other: Any) -> bool: 864 other = typing.cast(Self, other) # super makes sure this is true 865 return ( 866 super().pattern_eq(other) 867 and len(self.outputs) == len(other.outputs) 868 and all( 869 a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b 870 for a, b in zip(self.outputs, other.outputs) 871 ) 872 ) 873 874 875class RepeatedExpr(PatternExpr): 876 """ 877 Checks for a repeated pattern. 
Useful for repeated operations after a node such as `split` or `unbind` 878 """ 879 880 def __init__(self, inner_pattern: _TargetExpr) -> None: 881 super().__init__() 882 self.inner_pattern = inner_pattern 883 self.op = inner_pattern.op 884 885 @property 886 def fns(self) -> Sequence[FnsType]: 887 return self.inner_pattern.fns 888 889 def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: 890 m = ctx.match(self.inner_pattern, node) 891 if not is_match(m): 892 return m 893 ctx.pattern_to_node.pop( 894 self.inner_pattern, 895 ) 896 # Check all anchor nodes match the pattern 897 for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()): 898 anchor_m = MatchContext([self], graph=node.graph).match( 899 self.inner_pattern, anchor_node 900 ) 901 if not is_match(anchor_m): 902 return anchor_m 903 m.extend(anchor_m) 904 return m 905 906 def pattern_eq(self, other: Any) -> bool: 907 other = typing.cast(Self, other) # super makes sure this is true 908 return super().pattern_eq(other) and self.inner_pattern.pattern_eq( 909 other.inner_pattern 910 ) 911 912 913class PatternPrettyPrinter: 914 """ 915 Serializes Patterns to executable python. 916 XXX: currently only used and tested for fuse attention patterns. May not cover 917 all patterns. 918 """ 919 920 def __init__(self) -> None: 921 self.namespace = torch.fx.graph._Namespace() 922 self.memoized_objs_names: Dict[PatternExpr, str] = {} 923 self.memoized_objs_pp: Dict[PatternExpr, str] = {} 924 925 @staticmethod 926 @functools.lru_cache(None) 927 def run(obj: PatternExpr, output_name: str = "output") -> str: 928 """ 929 Serializes obj to python code with obj written out to `output_name` 930 """ 931 932 pp = PatternPrettyPrinter() 933 assert hasattr(obj, "pretty_print") 934 out_str = obj.pretty_print(pp=pp) 935 936 output = [] 937 for key in pp.memoized_objs_names: 938 output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}") 939 940 output.append(f"{output_name} = {out_str}") 941 942 return "\n".join(output) 943 944 def pretty_print(self, obj: Any) -> str: 945 if isinstance(obj, _TargetArgsExpr): 946 if memoized_name := self.memoized_objs_names.get(obj): 947 return memoized_name 948 else: 949 return self.memoize(obj) 950 if hasattr(obj, "pretty_print"): 951 return obj.pretty_print(self) 952 953 return repr(obj) 954 955 def memoize(self, obj: _TargetArgsExpr) -> str: 956 obj_str = obj.pretty_print(self) 957 obj_name = obj.fns_repr() 958 for prefix in ("aten.", "torch.", "prims."): 959 obj_name = obj_name.replace(prefix, "") 960 961 tmp_name = self.namespace.create_name(obj_name, None) 962 self.memoized_objs_names[obj] = tmp_name 963 self.memoized_objs_pp[obj] = obj_str 964 return tmp_name 965 966 967class _PassDictsType(Protocol): 968 def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: 969 ... 
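

# A minimal sketch of using the building blocks above by hand (the helper name
# below is hypothetical and nothing in this file calls it). Patterns are
# normally produced via register_replacement()/fx_to_pattern(), but
# CallFunction, KeywordArg, Ignored and Match can also be composed and matched
# directly against a traced graph.
def _example_manual_pattern_match() -> None:
    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1

    gm = torch.fx.symbolic_trace(f)
    # Match `relu(x) + 1`, capturing the relu input as kwarg "x" and ignoring
    # the constant addend.
    pattern = CallFunction(
        operator.add, CallFunction(torch.relu, KeywordArg("x")), Ignored()
    )
    for node in reversed(gm.graph.nodes):
        m = pattern.match(node)
        if is_match(m):
            log.debug("matched %s with kwargs %s", node, m.kwargs)
            break
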
970 971 972@dataclasses.dataclass 973class PatternEntry: 974 pattern: PatternExpr 975 extra_check: Callable[[Match], bool] 976 977 def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: 978 raise NotImplementedError 979 980 def register( 981 self, 982 pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], 983 target: Union[torch.fx.node.Target, None] = None, 984 prepend: bool = False, 985 ) -> None: 986 if target is None: 987 assert hasattr(self.pattern, "fns") 988 for fn in self.pattern.fns: 989 self.register(pass_dicts, fn, prepend=prepend) 990 elif isinstance(pass_dicts, (dict, PatternMatcherPass)): 991 assert hasattr(self.pattern, "op") 992 if prepend: 993 pass_dicts[(self.pattern.op, target)].insert(0, self) 994 else: 995 pass_dicts[(self.pattern.op, target)].append(self) 996 else: 997 pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts) 998 for x in pass_dicts: 999 self.register(x, target, prepend=prepend) 1000 1001 1002@dataclasses.dataclass 1003class LoweringPatternEntry(PatternEntry): 1004 handler: Callable[..., Any] 1005 1006 def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: 1007 handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) 1008 with graph.inserting_before(node): 1009 replacement = graph.call_function(handler, tuple(match.args), match.kwargs) 1010 replacement.meta.update(node.meta) 1011 node.replace_all_uses_with(replacement) 1012 assert match.nodes[-1] is node 1013 match.erase_nodes() 1014 1015 1016@dataclasses.dataclass 1017class GraphPatternEntry(PatternEntry): 1018 """ 1019 A pattern that runs a function on the FX graph 1020 """ 1021 1022 handler: Callable[..., Any] 1023 1024 def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: 1025 with graph.inserting_before(node): 1026 self.handler(match, *match.args, **match.kwargs) 1027 1028 1029@dataclasses.dataclass 1030class ReplacementPatternEntry(PatternEntry): 1031 normalize_args: Callable[..., List[Any]] 1032 1033 @staticmethod 1034 def replace_with_graph( 1035 match: Match, 1036 graph: torch.fx.Graph, 1037 replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], 1038 args: Sequence[torch.fx.Node], 1039 ) -> None: 1040 class Replacer(torch.fx.Interpreter): 1041 call_method = None # type: ignore[assignment] 1042 call_module = None # type: ignore[assignment] 1043 get_attr = None # type: ignore[assignment] 1044 1045 def run_node(self, node: torch.fx.Node) -> Any: 1046 if node.op in ("placeholder", "output"): 1047 return super().run_node(node) 1048 if node.op == "call_function": 1049 target = node.target 1050 args, kwargs = self.fetch_args_kwargs_from_env(node) 1051 result = graph.call_function(target, args, kwargs) # type: ignore[arg-type] 1052 if "val" in node.meta and "val" not in result.meta: 1053 result.meta["val"] = node.meta["val"] 1054 if isinstance(node.meta["val"], torch.Tensor): 1055 assert "tensor_meta" in node.meta 1056 result.meta["tensor_meta"] = node.meta["tensor_meta"] 1057 return result 1058 raise NotImplementedError(f"unhandled {node}") 1059 1060 output_nodes = match.output_nodes() 1061 1062 if len(output_nodes) == 1: 1063 last_node = output_nodes[0] 1064 else: 1065 assert output_nodes[0] 1066 nodes = list(output_nodes[0].graph.nodes) 1067 indices = [ 1068 (nodes.index(n), n) 1069 for n in output_nodes 1070 if isinstance(n, torch.fx.Node) 1071 ] 1072 last_node = min(indices, key=operator.itemgetter(0))[1] 1073 1074 def percolate_tags( 1075 node: torch.fx.Node, 1076 
tag_name: str, 1077 tag_value: str, 1078 input_stops: Set[torch.fx.Node], 1079 ) -> None: 1080 queue = [node] 1081 visited = set() 1082 1083 while queue: 1084 arg = queue.pop() 1085 if ( 1086 arg not in visited 1087 and arg not in input_stops 1088 and hasattr(arg, "meta") 1089 ): 1090 visited.add(arg) 1091 arg.meta[tag_name] = tag_value 1092 queue.extend(arg.all_input_nodes) 1093 1094 with graph.inserting_before(last_node): 1095 replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type] 1096 if isinstance(replacement, torch.fx.Node): 1097 replacement = [replacement] 1098 1099 def maybe_getitem(node: torch.fx.Node) -> Any: 1100 if node.op != "call_function": 1101 return None 1102 if node.target != operator.getitem: 1103 return None 1104 assert len(node.args) == 2 1105 return node.args[1] 1106 1107 def replace( 1108 old: Union[torch.fx.Node, None], 1109 new: Union[torch.fx.Node, Sequence[torch.fx.Node], None], 1110 ) -> None: 1111 if old is None: 1112 assert new is None 1113 return 1114 assert isinstance(old, torch.fx.Node) 1115 if new is None: 1116 old.replace_all_uses_with(None) # type: ignore[arg-type] 1117 graph.erase_node(old) 1118 return 1119 if isinstance(new, torch.fx.Node): 1120 if "val" not in new.meta: 1121 new.meta.update(old.meta) 1122 1123 # Preserve the recompute tags in the replacement graph. We 1124 # look at the recompute tags of the original output node to 1125 # propagate the tag from the output all the way to the input 1126 # args (named as args in the replace_with_graph). 1127 # Note that this is best effort. Since patterns are from 1128 # many to many, there is no easy way to correctly map the 1129 # recomputable tags. It is possible in some scenarios that we 1130 # incorrectly tag some nodes as recomputables. 1131 for tag_name in ["recompute", "ac_graph_id"]: 1132 if tag_name in old.meta: 1133 percolate_tags(new, tag_name, old.meta[tag_name], set(args)) 1134 1135 old.replace_all_uses_with(new) 1136 graph.erase_node(old) 1137 return 1138 1139 # `new` is not a node: it's a list of nodes. 1140 # 1141 # This happens when we want to replace a node that has a single 1142 # packed return with multiple unpacked returns. We need to do 1143 # some graph surgery here. 1144 # 1145 # Example: 1146 # def original_graph(x): 1147 # a = op(x) 1148 # b = a[0] 1149 # c = a[1] 1150 # ... 1151 # 1152 # Assume that we want to replace op(x) with the graph 1153 # def new_op(x): 1154 # w = x + 1 1155 # z = x + 2 1156 # return (w, z) 1157 # 1158 # We need to replace `op` with the contents of `new_op`, 1159 # and then rewrite a[0] to be w and a[1] to be z, as so: 1160 # def new_graph(x): 1161 # w = x + 1 1162 # z = x + 2 1163 # b = w 1164 # c = z 1165 # ... 
1166 old_uses = list(old.users.keys()) 1167 for user in old_uses: 1168 idx = maybe_getitem(user) 1169 if idx is None: 1170 raise AssertionError("can't handle") 1171 replace(user, new[idx]) # type: ignore[index] 1172 graph.erase_node(old) 1173 1174 if len(output_nodes) == len(replacement): 1175 for old, new in zip(output_nodes, replacement): 1176 replace(old, new) 1177 else: 1178 assert len(output_nodes) == 1 1179 replace(output_nodes[0], replacement) 1180 1181 match.erase_nodes() 1182 1183 def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: 1184 assert match.replacement_graph is not None 1185 self.replace_with_graph( 1186 match, 1187 graph, 1188 match.replacement_graph, 1189 self.normalize_args(*match.args, **match.kwargs), 1190 ) 1191 1192 1193def _return_true(match: Match) -> bool: 1194 return True 1195 1196 1197def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: 1198 log.info( 1199 "Replacement pattern %s failed to apply due to shape mismatch: %s", 1200 search_fn.__name__, 1201 e, 1202 ) 1203 1204 1205def register_replacement( 1206 search_fn: SearchFn, 1207 replace_fn: ReplaceFn, 1208 example_inputs: Iterable[Any], 1209 trace_fn: TraceFn, 1210 pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], 1211 extra_check: Callable[[Match], bool] = _return_true, 1212 scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, 1213 exclusive_arg_names: Sequence[str] = (), 1214 search_fn_pattern: Union[PatternExpr, None] = None, 1215) -> bool: 1216 """ 1217 Create a replacement rule based on example functions that get traced 1218 to create patterns. This supports both training and inference when 1219 run on a joint forward+backward graph. 1220 1221 Args: 1222 search_fn: traced to give original pattern 1223 replace_fn: traced to give replacement graph 1224 example_inputs: example inputs for initial trace 1225 trace_fn: fwd_only or joint_fwd_bwd 1226 pass_dict: dict of passes to register to 1227 extra_check: additional check to run on match(using real shapes) 1228 """ 1229 argnames_static = [*inspect.signature(search_fn).parameters.keys()] 1230 1231 def check_fn(match: Match) -> bool: 1232 """ 1233 Often shapes get burned into the pattern, so our initial match ran with 1234 `ignore_types=(int, ...)`. 1235 1236 Recheck the match with the correct shapes. 1237 """ 1238 argnames = list(argnames_static) 1239 for name in argnames: 1240 if name not in match.kwargs: 1241 raise RuntimeError( 1242 f"Not all inputs to pattern found in match.kwargs. Perhaps one " 1243 f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}" 1244 ) 1245 1246 args = list( 1247 torch.fx.map_arg( # type: ignore[arg-type] 1248 [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] 1249 ) 1250 ) 1251 sym_args: List[torch.SymInt] = [] 1252 with torch._dynamo.utils.detect_fake_mode(args): 1253 for i, grad in enumerate(requires_grad): 1254 if isinstance(args[i], torch.Tensor): 1255 if grad and is_integer_dtype(args[i].dtype): 1256 return False 1257 1258 args[i] = torch.empty_strided( 1259 args[i].size(), 1260 args[i].stride(), 1261 dtype=args[i].dtype, 1262 device=args[i].device, 1263 requires_grad=grad, 1264 ) 1265 for v in itertools.chain(args[i].shape, args[i].stride()): 1266 if isinstance(v, torch.SymInt) and all( 1267 guard_size_oblivious(v != a) for a in sym_args 1268 ): 1269 sym_args.append(v) 1270 1271 # If we were given a pre-traced pattern then use that instead of 1272 # retracing. 
Note that this means the pattern has to be independent 1273 # of its args. 1274 specific_pattern = search_fn_pattern 1275 1276 if not specific_pattern: 1277 if sym_args: 1278 # AOT Autograd and make fx will dedupe symbolic shape size 1279 # accesses of sym ints that appear as inputs 1280 # We don't want the sym_size uses to interfere with pattern matching 1281 # so we provide them as inputs. 1282 # Later, when we actually do the replacement, the symbolic shape 1283 # sizes will get re-traced and added to the graph. 1284 1285 def search_fn_new(*args_new: Any) -> Any: 1286 return search_fn(*args_new[len(args_new) - len(args) :]) 1287 1288 try: 1289 specific_graph = trace_fn(search_fn_new, sym_args + args) 1290 except RuntimeError as e: 1291 log_trace_failure(search_fn, e) 1292 return False 1293 1294 # correct argnames in the graph 1295 sym_arg_names = [] 1296 for i, placeholder in zip( 1297 range(len(sym_args) + len(args)), 1298 specific_graph.graph.nodes, 1299 ): 1300 if i < len(sym_args): 1301 sym_arg_names.append(placeholder.target) 1302 continue 1303 1304 with specific_graph.graph.inserting_after(placeholder): 1305 new_node = specific_graph.graph.placeholder( 1306 argnames[i - len(sym_args)] 1307 ) 1308 new_node.target = new_node.name 1309 placeholder.replace_all_uses_with(new_node) 1310 specific_graph.graph.erase_node(placeholder) 1311 1312 argnames = sym_arg_names + argnames 1313 else: 1314 try: 1315 specific_graph = trace_fn(search_fn, args) 1316 except RuntimeError as e: 1317 log_trace_failure(search_fn, e) 1318 return False 1319 1320 specific_pattern = fx_to_pattern( 1321 specific_graph, 1322 argnames=argnames, 1323 exclusive_arg_names=exclusive_arg_names, 1324 scalar_workaround=scalar_workaround, 1325 ) 1326 1327 node = match.output_nodes()[0] 1328 assert node is not None 1329 specific_pattern_match = specific_pattern.match(node) 1330 1331 if is_match(specific_pattern_match) and extra_check(specific_pattern_match): 1332 # trace the pattern using the shapes from the user program 1333 match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] 1334 return True 1335 return False 1336 1337 def normalize_args(**kwargs: Any) -> List[Any]: 1338 args = [] 1339 for name in argnames_static: 1340 args.append(kwargs.pop(name)) 1341 for i in range(1, len(kwargs) + 1): 1342 if f"tangents_{i}" not in kwargs: 1343 break 1344 args.append(kwargs.pop(f"tangents_{i}")) 1345 assert not kwargs, f"leftover kwargs: {kwargs!r}" 1346 return args 1347 1348 if trace_fn is joint_fwd_bwd: 1349 # If inference mode is enabled during compilation, assume that we don't 1350 # want to match on any training graph patterns 1351 if torch.is_inference_mode_enabled(): 1352 return False 1353 1354 # TODO: Revisit the functionalize_rng_ops for lowmem dropout 1355 with functorch_config.patch(functionalize_rng_ops=False): 1356 requires_grad: List[bool] = [ 1357 isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs 1358 ] 1359 if search_fn_pattern is None: 1360 pattern = gen_pattern( 1361 search_fn, 1362 example_inputs, 1363 trace_fn, 1364 scalar_workaround, 1365 exclusive_arg_names, 1366 ) 1367 else: 1368 pattern = search_fn_pattern 1369 1370 pattern_repr = PatternPrettyPrinter.run(pattern) 1371 assert pattern_repr not in _seen_patterns 1372 _seen_patterns.add(pattern_repr) 1373 pattern = ReplacementPatternEntry( 1374 pattern=pattern, 1375 extra_check=check_fn, 1376 normalize_args=normalize_args, 1377 ) 1378 pattern.register(pass_dicts) 1379 return pattern.pattern 1380 1381 
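
# A minimal sketch of how register_replacement() is typically called (the
# search/replace functions and the pass object here are hypothetical examples,
# not patterns that inductor actually ships). PatternMatcherPass and fwd_only
# are defined later in this file; the names are only looked up if this helper
# is ever run.
def _example_register_replacement() -> None:
    def _mul_by_two(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    def _add_to_self(x: torch.Tensor) -> torch.Tensor:
        return x + x

    example_pass = PatternMatcherPass()
    register_replacement(
        search_fn=_mul_by_two,
        replace_fn=_add_to_self,
        example_inputs=[torch.empty(4, 8)],
        trace_fn=fwd_only,  # inference pattern; use joint_fwd_bwd for training graphs
        pass_dicts=example_pass,
    )
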
1382_serialized_patterns: Set[str] = set() 1383 1384 1385def _serialize_pattern( 1386 unique_name: str, 1387 search_fn: SearchFn, 1388 example_inputs: Iterable[Any], 1389 trace_fn: TraceFn, 1390 scalar_workaround: Union[Dict[str, Union[float, int]], None], 1391) -> PatternExpr: 1392 def get_file_template() -> str: 1393 auto_generated_msg = textwrap.dedent( 1394 """\ 1395 # This is an auto-generated file. Please do not modify it by hand. 1396 # To re-generate, run: 1397 # cd ~/pytorch && python torchgen/fuse/gen_patterns.py 1398 """ 1399 ) 1400 1401 file_template = textwrap.dedent( 1402 """\ 1403 # mypy: ignore-errors 1404 1405 # noqa: F401, E501 1406 {msg} 1407 import torch 1408 import torch._inductor 1409 1410 aten = torch.ops.aten 1411 prims = torch.ops.prims 1412 1413 """ 1414 ).format(msg=auto_generated_msg) 1415 1416 pattern_matcher_imports = [] 1417 for name in dir(torch._inductor.pattern_matcher): 1418 attr = getattr(torch._inductor.pattern_matcher, name) 1419 if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)): 1420 pattern_matcher_imports.append(name) 1421 1422 formatted_imports = ",\n ".join(pattern_matcher_imports) 1423 formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n" 1424 return f"{file_template}{formatted_imports}" 1425 1426 if not SERIALIZED_PATTERN_PATH.is_dir(): 1427 raise RuntimeError( 1428 f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}" 1429 ) 1430 1431 pattern_name = search_fn.__name__ 1432 1433 from torch._functorch import config as functorch_config 1434 1435 with functorch_config.patch(functionalize_rng_ops=False): 1436 pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround) 1437 1438 serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name) 1439 if pattern_name not in _serialized_patterns: 1440 write_mode = "w" 1441 _serialized_patterns.add(pattern_name) 1442 else: 1443 write_mode = "a" 1444 1445 file_template = get_file_template() 1446 1447 with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f: 1448 if write_mode == "w": 1449 f.write(file_template) 1450 else: 1451 f.write("\n\n") 1452 f.write(serialized_pattern) 1453 f.write("\n") 1454 1455 return pattern 1456 1457 1458SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns" 1459 1460# This is the set of serialized patterns that we've registered. Used by 1461# test_serialized_patterns_up_to_date() to ensure the patterns are up 1462# to date. 1463_known_precompiled_patterns: List[ 1464 Tuple[ 1465 Any, 1466 Iterable[Any], 1467 Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], 1468 Any, 1469 PatternExpr, 1470 ] 1471] = [] 1472 1473 1474def gen_register_replacement( 1475 unique_name: str, 1476 search_fn: SearchFn, 1477 replace_fn: ReplaceFn, 1478 example_inputs: Iterable[Any], 1479 trace_fn: TraceFn, 1480 pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], 1481 extra_check: Callable[[Match], bool] = _return_true, 1482 scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, 1483 exclusive_arg_names: Sequence[str] = (), 1484 skip_duplicates: bool = False, 1485) -> None: 1486 # Make sure the example_inputs is materialized. 
    example_inputs = tuple(example_inputs)

    if "PYTORCH_GEN_PATTERNS" in os.environ:
        pat = _serialize_pattern(
            unique_name, search_fn, example_inputs, trace_fn, scalar_workaround
        )
    else:
        pattern_name = search_fn.__name__
        m = importlib.import_module(
            f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}"
        )
        if not m or not hasattr(m, unique_name):
            log.warning(
                "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.",
                unique_name,
            )
        pat = getattr(m, unique_name)

    for arg in pytree.tree_iter(example_inputs):
        if isinstance(arg, FakeTensor) and arg.constant is not None:
            # This can be a problem - small fake tensors (e.g. `tensor(2)`) will
            # hold onto their original constant value - and by stashing it here
            # will cause a memory leak if the constant value is on GPU.
            # Since this is just an optimization we can clear it out.
            arg.constant = None

    if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates:
        return
    _known_precompiled_patterns.append(
        (search_fn, example_inputs, trace_fn, scalar_workaround, pat)
    )
    register_replacement(
        search_fn,
        replace_fn,
        example_inputs,
        trace_fn,
        pass_dicts,
        extra_check,
        scalar_workaround,
        exclusive_arg_names,
        search_fn_pattern=pat,
    )


@functorch_config.patch(functionalize_rng_ops=False)
def gen_pattern(
    search_fn: SearchFn,
    example_inputs: Sequence[Any],
    trace_fn: TraceFn,
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
    argnames = [*inspect.signature(search_fn).parameters.keys()]

    if scalar_workaround is None:
        scalar_workaround = {}
    flat_inputs = []
    input_idx = 0  # Positional arguments index

    for argname in argnames:
        if argname in scalar_workaround:
            flat_inputs.append(scalar_workaround[argname])
        else:
            flat_inputs.append(example_inputs[input_idx])
            input_idx += 1

    search_gm = trace_fn(search_fn, flat_inputs)
    return fx_to_pattern(
        search_gm,
        ignore_types=(int, float, list, torch.device, torch.dtype),
        argnames=argnames,
        scalar_workaround=scalar_workaround,
        exclusive_arg_names=exclusive_arg_names,
    )


def register_lowering_pattern(
    pattern: PatternExpr,
    extra_check: Callable[[Match], bool] = _return_true,
    *,
    pass_dict: _PassDictsType,
    prepend: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Register an aten to inductor IR replacement pattern. The decorated
    function is saved and then called at lowering time, allowing direct
    pattern to inductor IR conversion.
1574 """ 1575 1576 def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: 1577 assert callable(handler) 1578 LoweringPatternEntry( 1579 pattern=pattern, extra_check=extra_check, handler=handler 1580 ).register(pass_dict, prepend=prepend) 1581 handler._inductor_lowering_function = True # type: ignore[attr-defined] 1582 return handler 1583 1584 return decorator 1585 1586 1587def register_graph_pattern( 1588 pattern: PatternExpr, 1589 extra_check: Callable[[Match], bool] = _return_true, 1590 *, 1591 pass_dict: _PassDictsType, 1592 prepend: bool = False, 1593) -> Callable[[Callable[..., Any]], Callable[..., Any]]: 1594 """ 1595 Register a pattern that runs a function on the FX graph, allowing 1596 custom transformation code. 1597 """ 1598 1599 def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: 1600 assert callable(handler) 1601 GraphPatternEntry( 1602 pattern=pattern, extra_check=extra_check, handler=handler 1603 ).register(pass_dict, prepend=prepend) 1604 return handler 1605 1606 return decorator 1607 1608 1609def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool: 1610 # first node in the graph 1611 return node is next(iter(graph.nodes)) 1612 1613 1614# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc 1615# doesn't match: __rshift__, etc 1616_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)") 1617 1618 1619def is_mutation_op(node: torch.fx.Node) -> bool: 1620 if node.op == "call_function": 1621 if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr] 1622 return True 1623 elif node.op == "call_method": 1624 if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type] 1625 return True 1626 return node.kwargs.get("out") is not None 1627 1628 1629def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool: 1630 assert "mutation_region_id" in a.meta 1631 assert "mutation_region_id" in b.meta 1632 return a.meta["mutation_region_id"] == b.meta["mutation_region_id"] 1633 1634 1635def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: 1636 n = node 1637 while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n): 1638 n = n.prev 1639 mutation_region_id = n.meta.get("mutation_region_id", 0) 1640 while n is not node: 1641 n = n.next 1642 if is_mutation_op(n): 1643 mutation_region_id += 1 1644 n.meta["mutation_region_id"] = mutation_region_id 1645 return mutation_region_id 1646 1647 1648def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: 1649 return "mutation_region_id" not in next(iter(graph.nodes)).meta 1650 1651 1652def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None: 1653 mutation_region_id = 0 1654 for nd in graph.nodes: 1655 if is_mutation_op(nd): 1656 mutation_region_id += 1 1657 nd.meta["mutation_region_id"] = mutation_region_id 1658 1659 1660class PatternMatcherPass: 1661 def __init__( 1662 self, 1663 pass_name: Optional[str] = None, 1664 ) -> None: 1665 super().__init__() 1666 self.patterns: DefaultDict[ 1667 Tuple[str, torch.fx.node.Target], List[PatternEntry] 1668 ] = defaultdict(list) 1669 self.pass_name = pass_name 1670 1671 def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: 1672 return self.patterns[item] 1673 1674 def apply(self, gm: torch.fx.GraphModule) -> int: 1675 if not self.patterns: 1676 return 0 1677 if isinstance(gm, torch.fx.GraphModule): 1678 graph = gm.graph 1679 elif isinstance(gm, torch.fx.Graph): 1680 
graph = gm 1681 gm = graph.owning_module 1682 else: 1683 raise RuntimeError( 1684 f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" 1685 ) 1686 if should_compute_mutation_region_ids(graph): # type: ignore[arg-type] 1687 compute_mutation_region_ids(graph) # type: ignore[arg-type] 1688 get_mutation_region_id_partial = functools.partial( 1689 get_mutation_region_id, graph 1690 ) 1691 count = 0 1692 nodes = [] 1693 has_call_module = False 1694 for op, target in self.patterns: 1695 if op == "call_module": 1696 has_call_module = True 1697 else: 1698 nodes.append(graph.find_nodes(op=op, target=target, sort=False)) 1699 if has_call_module: 1700 nodes.append(graph.find_nodes(op="call_module", sort=False)) 1701 pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher" 1702 with GraphTransformObserver( 1703 gm, pass_name, trace_config.log_url_for_graph_xform 1704 ): 1705 for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): 1706 target = extract_target(node) 1707 if node.op == "call_module": 1708 if (node.op, target) not in self.patterns: 1709 continue 1710 1711 # conservatively not applying pattern for cpu input, 1712 # since some of the patterns induce codegen and split nodes. 1713 # Note: we will only skip cpu compute if disable_cpp_codegen=True 1714 if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): 1715 continue 1716 1717 for entry in self.patterns[(node.op, target)]: 1718 if node._erased: 1719 break 1720 m = entry.pattern.match(node) 1721 # pattern match crosses mutation barrier - discard 1722 if ( 1723 is_match(m) 1724 and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] 1725 ): 1726 continue 1727 if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: 1728 log.warning("%s%s %s %s", node, node.args, m, entry.pattern) 1729 if is_match(m) and entry.extra_check(m): 1730 count += 1 1731 entry.apply(m, graph, node) # type: ignore[arg-type] 1732 counters["inductor"]["pattern_matcher_count"] += 1 1733 counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) 1734 return count 1735 1736 def clear(self) -> None: 1737 self.patterns.clear() 1738 1739 1740def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: 1741 raise NotImplementedError 1742 1743 1744def fx_to_pattern( 1745 gm: Union[torch.fx.GraphModule, torch.fx.Graph], 1746 ignore_types: Sequence[Type[Any]] = (), 1747 argnames: Sequence[str] = (), 1748 scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, 1749 exclusive_arg_names: Sequence[str] = (), 1750) -> PatternExpr: 1751 """ 1752 Convert an FX graph into a PatternExpr. This is useful for simple 1753 patterns that can only match single functions and fixed-length lists. 
1754 """ 1755 # scalar_workaround is a hack to capture dropout_p 1756 # see https://github.com/pytorch/pytorch/issues/97894 1757 scalar_workaround = scalar_workaround or {} 1758 inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} 1759 assert len(inv_scalar_workaround) == len(scalar_workaround) 1760 1761 def process_arg(x: T) -> Union[T, KeywordArg, Ignored]: 1762 if isinstance(x, (float, int)) and x in inv_scalar_workaround: 1763 return KeywordArg(inv_scalar_workaround[x]) 1764 if type(x) in ignore_types: 1765 return Ignored() 1766 if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x: 1767 return Ignored() 1768 return x 1769 1770 argnum = itertools.count() 1771 1772 class Converter(torch.fx.Interpreter): 1773 call_method = _not_implemented 1774 call_module = _not_implemented 1775 get_attr = _not_implemented 1776 1777 def placeholder( 1778 self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] 1779 ) -> Union[ExclusiveKeywordArg, KeywordArg]: 1780 n = next(argnum) 1781 if n < len(argnames): 1782 name = argnames[n] 1783 elif argnames: 1784 assert target.startswith("tangent") 1785 name = target 1786 else: 1787 target = re.sub(r"_\d+$", "", target) # de-mangle arg name 1788 name = target 1789 if name in exclusive_arg_names: 1790 return ExclusiveKeywordArg(name) 1791 else: 1792 return KeywordArg(name) 1793 1794 def call_function( 1795 self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] 1796 ) -> PatternExpr: 1797 args, kwargs = pytree.tree_map(process_arg, (args, kwargs)) 1798 if list in ignore_types: 1799 # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] 1800 args = [process_arg(a) for a in args] 1801 kwargs = {k: process_arg(a) for k, a in kwargs.items()} 1802 return CallFunction(target, *args, **kwargs) 1803 1804 def run_node(self, n: torch.fx.Node) -> Any: 1805 rv = super().run_node(n) 1806 if n.op == "output" and isinstance(rv, tuple): 1807 assert len(rv) == len(n.args[0]) # type: ignore[arg-type] 1808 for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type] 1809 r.users = len(arg.users) 1810 else: 1811 rv.users = len(n.users) 1812 return rv 1813 1814 pattern = Converter(gm).run() # type: ignore[arg-type] 1815 if not isinstance(pattern, PatternExpr): 1816 return MultiOutputPattern(pytree.tree_leaves(pattern)) 1817 return pattern 1818 1819 1820@torch.no_grad() 1821def fwd_only( 1822 fn: Callable[..., Any], 1823 args: Sequence[Any], 1824 *, 1825 run_functional_passes: bool = True, 1826 get_decomp_fn: Optional[Callable[..., Any]] = None, 1827) -> torch.fx.GraphModule: 1828 """Build a normalized inference graph, for use with fx_to_pattern""" 1829 # TODO - look into using aot autograd, asserting no mutating ops here 1830 with enable_python_dispatcher(): 1831 decompositions = ( 1832 get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() 1833 ) 1834 gm = make_fx(fn, decompositions, tracing_mode="real")(*args) 1835 1836 from .fx_passes.post_grad import remove_noop_ops 1837 1838 if run_functional_passes: 1839 remove_noop_ops(gm.graph) 1840 gm.graph.eliminate_dead_code() 1841 1842 gm.recompile() 1843 return gm 1844 1845 1846@torch.enable_grad() 1847def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: 1848 """Build a normalized training graph, for use with fx_to_pattern""" 1849 gm: Optional[torch.fx.GraphModule] = None 1850 1851 def record_joint_graph( 1852 joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], 
**kwargs: Any 1853 ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]: 1854 nonlocal gm 1855 assert not gm 1856 gm = clone_graph(joint_graph) 1857 return default_partition(joint_graph, inputs, **kwargs) 1858 1859 with torch._guards.tracing(None): 1860 aot_function( 1861 fn, 1862 lambda g, i: make_boxed_func(g), 1863 partition_fn=record_joint_graph, 1864 decompositions=select_decomp_table(), 1865 keep_inference_input_mutations=True, 1866 enable_log=False, 1867 )(*args) 1868 assert gm 1869 1870 from .fx_passes.post_grad import remove_noop_ops 1871 1872 remove_noop_ops(gm.graph) 1873 1874 from .fx_passes.joint_graph import pointless_view 1875 1876 matcher_pass = PatternMatcherPass() 1877 1878 pattern = CallFunction( 1879 torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size") 1880 ) 1881 GraphPatternEntry( 1882 pattern=pattern, handler=pointless_view, extra_check=_return_true 1883 ).register(matcher_pass.patterns) 1884 matcher_pass.apply(gm.graph) # type: ignore[arg-type] 1885 1886 # remove in/out specs 1887 gm.graph._codegen = torch.fx.graph.CodeGen() 1888 gm.graph.eliminate_dead_code() 1889 gm.recompile() 1890 return gm 1891 1892 1893def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: 1894 args: List[torch.fx.node.Argument] = [] 1895 torch.fx.map_arg((n.args, n.kwargs), args.append) 1896 return args 1897 1898 1899def stable_topological_sort(graph: torch.fx.Graph) -> None: 1900 # Nodes are in exactly one of these three collections: 1901 1902 # - Nodes in `pending` are waiting to be processed (in reverse order): 1903 pending = list(reversed(graph.nodes)) 1904 1905 # - Nodes in `ready` have been processed and are already in the correct 1906 # order. 1907 ready = set() 1908 1909 # - `waiting` is a mapping from a dependency to nodes which depend on that 1910 # dependency. 1911 waiting = defaultdict(list) 1912 1913 # The cursor indicates the last processed node so we can add new nodes 1914 # after it. 1915 cursor = None 1916 while pending: 1917 node = pending.pop() 1918 waiting_for = [x for x in _args(node) if x not in ready] 1919 if waiting_for: 1920 # We have unprocessed input nodes. Might as well wait for the last 1921 # arg so an already sorted list will only recheck this node once. 1922 waiting[waiting_for[-1]].append(node) 1923 else: 1924 ready.add(node) 1925 if cursor and cursor.next is not node: 1926 cursor.append(node) 1927 cursor = node 1928 # Mark the nodes that have been waiting for this node to finish as 1929 # ready to check again. 
1930 pending.extend(reversed(waiting.pop(node, ()))) 1931 1932 assert not waiting and len(ready) == len(graph.nodes) 1933 1934 1935def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: 1936 """Wrapper around lazy init functions in fx_passes/""" 1937 1938 @functools.lru_cache(None) 1939 @functools.wraps(fn) 1940 def lazy_init() -> Any: 1941 counters_ref = counters["inductor"].copy() 1942 1943 with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): 1944 result = fn() 1945 1946 # clear view matches encountered during tracing 1947 counters["inductor"] = counters_ref 1948 1949 return result 1950 1951 return lazy_init 1952 1953 1954def config_flag(name: str) -> Callable[[Match], Any]: 1955 """Function for extra_check to put pass behind a flag""" 1956 1957 def flag_check(match: Match) -> Any: 1958 return getattr(config, name) 1959 1960 return flag_check 1961 1962 1963def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: 1964 class CopyGraph(Transformer): 1965 def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: 1966 new_node = super().run_node(old_node) 1967 if isinstance(new_node, torch.fx.Proxy): 1968 new_node.node.meta.update(old_node.meta) 1969 new_node.node.name = self.new_graph._graph_namespace.create_name( 1970 old_node.name, None 1971 ) 1972 return new_node 1973 1974 return CopyGraph(input_graph).transform() 1975 1976 1977_seen_patterns: Set[str] = set() 1978 1979 1980def get_arg_value( 1981 node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None 1982) -> Any: 1983 return ( 1984 node.args[arg_number] 1985 if len(node.args) > arg_number 1986 else node.kwargs.get(kwarg_name) # type: ignore[arg-type] 1987 ) 1988 1989 1990def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]: 1991 fns = [fn] 1992 if isinstance(fn, torch._ops.OpOverloadPacket): 1993 fns.extend([getattr(fn, overload) for overload in fn.overloads()]) 1994 1995 return [node for node in nodes if node.target in fns] 1996 1997 1998def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: 1999 """For call_function and call_method, we directly use the target function; 2000 For call_module, the target is string, and we treat the module class 2001 as a function. 2002 """ 2003 if node.op == "call_module": 2004 return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type] 2005 return node.target 2006
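

# A minimal sketch of wiring a graph pattern into a standalone
# PatternMatcherPass (added for documentation only; the handler and the pass
# are hypothetical - real passes live under torch._inductor.fx_passes).
def _example_graph_pattern_pass(gm: torch.fx.GraphModule) -> int:
    example_pass = PatternMatcherPass(pass_name="example_pass")

    @register_graph_pattern(
        CallFunction(aten.add.Tensor, KeywordArg("x"), KeywordArg("y")),
        pass_dict=example_pass,
    )
    def _log_add(match: Match, *, x: NodeOrConstant, y: NodeOrConstant) -> None:
        # Handlers may rewrite match.graph in place; this one only logs.
        log.debug("found aten.add of %s and %s", x, y)

    # Returns the number of matches whose extra_check passed and were applied.
    return example_pass.apply(gm)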