xref: /aosp_15_r20/external/pytorch/torch/_inductor/pattern_matcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2"""
3# Inductor Pattern Matcher
4
5The pattern matcher enables search/replace within an FX graph.
6
7The main entrypoint to the pattern matcher is register_replacement(). Given a
8search function and a replacement function, this registers a replacement with
9a pass (such as torch._inductor.fx_passes.joint_graph.patterns).
10
11Internally the pattern matcher represents patterns as a graph (a DAG). Creating
12new patterns manually as a graph is cumbersome and error-prone, so the standard
13way to create patterns (using register_replacement()) is to provide a search
14function and a replacement function, which are traced and converted into graphs.
15
16Because the search functions are written to be somewhat generic (they tend to
17ignore tensor sizes, for example), register_replacement() allows you to specify
18an `extra_check` function which performs additional checks to verify that the
19matched pattern fully matches before returning it.
20
21## Precompiled Patterns
22
23New patterns are added using register_replacement(). Patterns added in this way
24incur a compile-time overhead because they need to be traced before use. To
25avoid this, patterns can be precompiled and registered by calling
26gen_register_replacement() instead of register_replacement(). The arguments are
27the same except for an additional unique name, which is used as a lookup key
28for the precompiled pattern.
29
30## Internals
31
32The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
33implements a `_match` method which returns either a `Match` object for a
34successful match or a `FailedMatch` object for a failure to match.
35"""
36
37from __future__ import annotations
38
39import contextlib
40import dataclasses
41import functools
42import importlib
43import inspect
44import itertools
45import logging
46import operator
47import os
48import re
49import textwrap
50import typing
51from abc import ABC, abstractmethod
52from collections import defaultdict
53from pathlib import Path
54from typing import (
55    Any,
56    Callable,
57    DefaultDict,
58    Dict,
59    Generator,
60    Iterable,
61    List,
62    Mapping,
63    NoReturn,
64    Optional,
65    Protocol,
66    Sequence,
67    Set,
68    Tuple,
69    Type,
70    TypeVar,
71    Union,
72)
73from typing_extensions import Self, TypeGuard
74
75import torch
76import torch._guards
77import torch.fx
78import torch.utils._pytree as pytree
79from torch._dispatch.python import enable_python_dispatcher
80from torch._dynamo.utils import counters
81from torch._inductor.config import trace as trace_config
82from torch._prims_common import is_integer_dtype
83from torch._subclasses.fake_tensor import unset_fake_temporarily
84from torch.fx.experimental.proxy_tensor import make_fx
85from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
86from torch.fx.immutable_collections import immutable_dict, immutable_list
87from torch.fx.passes.graph_transform_observer import GraphTransformObserver
88
89from .._functorch import config as functorch_config
90from .._functorch.aot_autograd import aot_function, make_boxed_func
91from .._functorch.partitioners import default_partition
92from .._subclasses import FakeTensor, FakeTensorMode
93from ..fx import Transformer
94from . import config
95from .decomposition import select_decomp_table
96from .lowering import fallback_node_due_to_unsupported_type
97
98
99log = logging.getLogger(__name__)
100aten = torch.ops.aten
101prims = torch.ops.prims
102
103Constant = Any
104NodeOrConstant = Union[Constant, torch.fx.Node]
105
106
107class SearchFn(Protocol):
108    __name__: str
109
110    def __call__(self, *args: Any, **kwargs: Any) -> Any:
111        ...
112
113
114class ReplaceFn(Protocol):
115    def __call__(self, *args: Any, **kwargs: Any) -> Any:
116        ...
117
118
119class TraceFn(Protocol):
120    def __call__(
121        self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
122    ) -> torch.fx.GraphModule:
123        ...
124
125
126T = TypeVar("T")
127
128# What's a better name for this?
129FnsType = Union[torch.fx.node.Target, str]
130
131
132class Multiple:
133    def __init__(self) -> None:
134        # Ensure we're really a singleton.
135        assert "MULTIPLE" not in globals() or self is MULTIPLE
136
137
138# Sentinel indicating multiple quantities can be matched
139MULTIPLE = Multiple()
140
141
142class Match:
143    """
144    Represents a successfully matched pattern.
145
146    The `Match` object is returned to represent a successfully matched
147    pattern. Included in the Match are the pattern that was matched, the graph
148    nodes matched, and any args that were used during the matching.
149
150    The args and kwargs are specific to the type of pattern that was matched and
151    provide hints about what was matched.
152    """
153
154    pattern: PatternExpr
155    args: List[Any]
156    kwargs: Dict[str, Any]
157    nodes: List[torch.fx.Node]
158    targets: Dict[_TargetExpr, torch.fx.node.Target]
159    ctx: MatchContext
160    replacement_graph: Optional[torch.fx.Graph]
161
162    def __init__(
163        self,
164        ctx: MatchContext,
165        pattern: PatternExpr,
166        args: Optional[Sequence[Any]] = None,
167        kwargs: Optional[Dict[str, Any]] = None,
168    ) -> None:
169        super().__init__()
170        self.pattern = pattern
171        # The input nodes that must be passed in to the result
172        self.args = list(args or [])
173        self.kwargs = kwargs or {}
174        # The nodes matched in this expression
175        self.nodes = []
176        # Mapping CallFunction to the node.target
177        self.targets = {}
178        self.ctx = ctx
179        self.replacement_graph = None
180
181    @property
182    def graph(self) -> torch.fx.Graph:
183        return self.ctx.graph
184
185    def extend(self, other: Match) -> None:
186        if self.kwargs:
187            for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
188                if self.kwargs[key] != other.kwargs[key]:
189                    raise FailedMatch("kwarg mismatch: {}", key)
190        self.args.extend(other.args)
191        self.nodes.extend(other.nodes)
192        self.kwargs.update(other.kwargs)
193        self.targets.update(other.targets)
194
195    def bundle(self) -> Match:
196        # Bundle all positional args into a single tuple argument
197        self.args = [tuple(self.args)] if self.args else []
198        return self
199
200    def __repr__(self) -> str:
201        return f"Match(..., {self.args}, {self.kwargs})"
202
203    def erase_nodes(self) -> None:
204        graph = self.graph
205        for n in reversed(self.nodes):
206            if not n._erased and not n.users:
207                graph.erase_node(n)
208
209    def output_nodes(self) -> List[Optional[torch.fx.Node]]:
210        return [
211            (self.ctx.pattern_to_node[p] if p is not None else None)
212            for p in self.ctx.outputs
213        ]
214
215    def output_node(self) -> torch.fx.Node:
216        return next(p for p in self.output_nodes() if p)
217
218    def replace_with_graph(
219        self, replacement_graph: torch.fx.Graph, args: Sequence[Any]
220    ) -> None:
221        ReplacementPatternEntry.replace_with_graph(
222            self, self.ctx.graph, replacement_graph, args
223        )
224
225    def replace_by_example(
226        self,
227        replacement_fn: ReplaceFn,
228        args: Sequence[Any],
229        trace_fn: Optional[TraceFn] = None,
230        run_functional_passes: bool = True,
231    ) -> None:
232        """Replace with a graph generated by tracing the replacement_fn.
233
234        Args:
235            run_functional_passes (bool): Whether to run passes that
236                assume functional IR (like DCE and remove_noop_ops) on the
237                replacement graph.
238
239        """
240        from torch._inductor.virtualized import NullHandler, V
241
242        context = (
243            V.fake_mode
244            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
245            else contextlib.nullcontext()
246        )
247
248        with context:
249            if trace_fn is None:
250                trace_fn = functools.partial(
251                    fwd_only, run_functional_passes=run_functional_passes
252                )
253            replacement = trace_fn(
254                replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
255            )
256            ReplacementPatternEntry.replace_with_graph(
257                self,
258                self.ctx.graph,
259                replacement,
260                args,
261            )
262
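# A typical consumer of a Match is an `extra_check` callback or a graph-pattern
# handler.  For instance (an illustrative sketch; "x" is simply whatever name a
# KeywordArg("x") in the pattern captured):
#
#   def _check_is_contiguous(match: Match) -> bool:
#       x = match.kwargs["x"]        # an fx.Node captured by the pattern
#       val = x.meta.get("val")      # FakeTensor recorded during tracing
#       return val is not None and val.is_contiguous()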
263
264class FailedMatch(RuntimeError):
265    """
266    Represents an unsuccessful match.
267
268    The `FailedMatch` object is returned to represent a failure to match a
269    pattern.
270    """
271
272    format_string: str
273
274    def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
275        self.format_string = format_string
276        # We want to construct error messages lazily instead of eagerly, as
277        # constructing them eagerly can significantly worsen compile times.
278        if len(format_string) > 200:
279            raise RuntimeError(
280                f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
281            )
282        self.args = args
283        self.kwargs = kwargs
284
285    def __str__(self) -> str:
286        return self.format_string.format(*self.args, **self.kwargs)
287
288    def __bool__(self) -> bool:
289        return False
290
291
292MatchResult = Union[Match, FailedMatch]
293
294
295def is_match(m: MatchResult) -> TypeGuard[Match]:
296    """
297    TypeGuards cannot act on `self`. Thus this function exists to let mypy
298    recognize FailedMatch.__bool__ as a TypeGuard.
299    """
300    return bool(m)
301
302
303class MatchContext:
304    """
305    Internal state needed while running PatternExpr._match().
306    """
307
308    outputs: List[Optional[PatternExpr]]
309    pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]]
310    graph: torch.fx.Graph
311    exclusive_node_set: List[NodeOrConstant]
312
313    def __init__(
314        self,
315        outputs: List[Optional[PatternExpr]],
316        pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None,
317        *,
318        graph: torch.fx.Graph,
319    ) -> None:
320        self.outputs = outputs
321        self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node)
322        self.graph = graph
323        self.exclusive_node_set = []
324
325    def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult:
326        """wrapper to check reused nodes in patterns"""
327        if pattern in self.pattern_to_node:
328            if self.pattern_to_node[pattern] == node:
329                return Match(self, pattern)  # already checked this node
330            else:
331                return FailedMatch("repeated pattern differs")
332        m = pattern._match(node, self)
333        assert pattern not in self.pattern_to_node
334        self.pattern_to_node[pattern] = node if m else None
335        return m
336
337    def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]:
338        return {
339            pattern: node
340            for pattern, node in self.pattern_to_node.items()
341            if pattern.has_multiple_users() and node is not None
342        }
343
344
345class PatternExpr(ABC):
346    """
347    Base class for types of patterns.
348    """
349
350    @abstractmethod
351    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
352        ...
353
354    def match(self, node: torch.fx.Node) -> MatchResult:
355        try:
356            return MatchContext([self], graph=node.graph).match(self, node)
357        except FailedMatch as e:
358            return e
359
360    def has_multiple_users(self) -> bool:
361        return False
362
363    def __repr__(self) -> str:
364        return self.__class__.__name__ + "()"
365
366    def find_anchor_nodes(
367        self, ctx: MatchContext, searched: Set[torch.fx.Node]
368    ) -> Generator[Optional[torch.fx.Node], None, None]:
369        if self in ctx.pattern_to_node:
370            yield ctx.pattern_to_node[self]
371
372    def pattern_eq(self, other: Any) -> bool:
373        """
374        Compare two `PatternExpr`s and return true if they are the
375        same. Note this is NOT matching a pattern - it is comparing the pattern
376        structures (for debugging).
377        """
378        return isinstance(other, self.__class__)
379
380
381class Arg(PatternExpr):
382    """
383    Capture an arg which will become an input to the handler.  Args are
384    passed in depth first order.
385    """
386
387    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
388        return Match(ctx, self, args=[node])  # matches anything
389
390
391class Ignored(PatternExpr):
392    """
393    Match an arg, but don't pass it to handler
394    """
395
396    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
397        return Match(ctx, self)  # matches anything
398
399    def __repr__(self) -> str:
400        return "*"
401
402    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
403        return "Ignored()"
404
405
406class KeywordArg(PatternExpr):
407    """
408    Capture a kwarg which will become an input to the handler.
409    """
410
411    def __init__(self, name: str) -> None:
412        super().__init__()
413        self.name = name
414
415    def __repr__(self) -> str:
416        return f"KeywordArg({self.name!r})"
417
418    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
419        return Match(ctx, self, kwargs={self.name: node})  # matches anything
420
421    def pattern_eq(self, other: Any) -> bool:
422        other = typing.cast(Self, other)  # super makes sure this is true
423        return super().pattern_eq(other) and self.name == other.name
424
425
426class ExclusiveKeywordArg(PatternExpr):
427    """
428    Capture a kwarg which will become an input to the handler; a given node may only match one ExclusiveKeywordArg.
429    """
430
431    name: str
432
433    def __init__(self, name: str) -> None:
434        super().__init__()
435        self.name = name
436
437    def __repr__(self) -> str:
438        return f"ExclusiveKeywordArg({self.name!r})"
439
440    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
441        if node in ctx.exclusive_node_set:
442            return FailedMatch("exclusive arg appears twice")
443
444        ctx.exclusive_node_set.append(node)
445        return Match(ctx, self, kwargs={self.name: node})  # matches anything
446
447    def pattern_eq(self, other: Any) -> bool:
448        other = typing.cast(Self, other)  # super makes sure this is true
449        return super().pattern_eq(other) and self.name == other.name
450
451
452class _TargetExpr(PatternExpr):
453    """
454    Base class for filtering match by node.target
455    """
456
457    fns: List[FnsType]
458    fns_set: Set[FnsType]
459
460    def __init__(
461        self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1
462    ) -> None:
463        super().__init__()
464        fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
465        for fn in fns:
466            if isinstance(fn, torch._ops.OpOverloadPacket):
467                fns.extend(getattr(fn, overload) for overload in fn.overloads())
468
469        self.fns = fns
470        self.fns_set = set(fns)
471        self.users = users
472
473    @property
474    @abstractmethod
475    def op(self) -> str:
476        ...
477
478    def fns_repr(self) -> str:
479        first_repr = self.fns[0]
480        if not isinstance(first_repr, str):
481            first_repr = first_repr.__name__
482
483        if len(self.fns) > 1:
484            return f"[{first_repr}, ...]"
485        elif self.fns[0] is getattr(torch, first_repr, None):
486            return f"torch.{first_repr}"
487        elif isinstance(self.fns[0], torch._ops.OpOverload):
488            return str(self.fns[0])
489        else:
490            return first_repr
491
492    def __repr__(self) -> str:
493        if self.users is MULTIPLE:
494            comma_users = ", MULTIPLE"
495        elif self.users != 1:
496            comma_users = f", {self.users}"
497        else:
498            comma_users = ""
499        return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})"
500
501    def has_multiple_users(self) -> bool:
502        return isinstance(self.users, Multiple) or self.users > 1
503
504    def find_anchor_nodes(
505        self, ctx: MatchContext, searched: Set[torch.fx.Node]
506    ) -> Generator[Optional[torch.fx.Node], None, None]:
507        raise NotImplementedError
508
509    def _match_fns(self, node: torch.fx.Node) -> bool:
510        return (
511            isinstance(node, torch.fx.Node)
512            and node.op == self.op
513            and extract_target(node) in self.fns_set
514        )
515
516    def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool:
517        return (
518            self in ctx.outputs
519            or self.users is MULTIPLE
520            or len(node.users) == self.users
521        )
522
523    def pattern_eq(self, other: Any) -> bool:
524        other = typing.cast(Self, other)  # super makes sure this is true
525        return (
526            super().pattern_eq(other)
527            and self.op == other.op
528            and self.fns == other.fns
529            and self.users == other.users
530        )
531
532
533_SimpleSpec = Tuple[Any, ...]
534
535
536class _TargetArgsExpr(_TargetExpr):
537    """
538    Base class for filtering match by node.{target,args,kwargs}
539    """
540
541    def __init__(
542        self,
543        fns: Union[torch.fx.node.Target, str, Sequence[Any]],
544        *args: Any,
545        _users: Union[int, Multiple] = 1,
546        **kwargs: Any,
547    ) -> None:
548        super().__init__(fns, _users)
549        self.args = tuple(args)
550        self.kwargs = dict(kwargs)
551        if any(
552            isinstance(x, (dict, list, tuple))
553            for x in itertools.chain(args, kwargs.values())
554        ):
555            self.flatten = self.pytree_flatten
556        else:
557            self.flatten = self.simple_flatten
558        self.flat_args_kwargs = self.flatten(self.args, self.kwargs)
559
560    @staticmethod
561    def simple_flatten(
562        args: Sequence[Any], kwargs: Mapping[Any, Any]
563    ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
564        values = (*args, *kwargs.values())
565        spec = (len(args), *kwargs.keys())
566        return values, spec
567
568    @staticmethod
569    def pytree_flatten(
570        args: Sequence[Any], kwargs: Mapping[Any, Any]
571    ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
572        def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
573            if s.type is None:
574                return s
575            mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
576            return pytree.TreeSpec(
577                mapping.get(s.type, s.type),
578                s.context,
579                list(map(norm_spec, s.children_specs)),
580            )
581
582        flat, spec = pytree.tree_flatten([args, kwargs])
583        spec = norm_spec(spec)
584        return flat, spec
585
586    def __repr__(self) -> str:
587        args = [
588            self.fns_repr(),
589            *map(repr, self.args),
590            *[f"{k}={v}" for k, v in self.kwargs.items()],
591        ]
592        if self.users is MULTIPLE:
593            args.append("_users=MULTIPLE")
594        elif self.users != 1:
595            args.append(f"_users={self.users}")
596        return f"{self.__class__.__name__}({', '.join(args)})"
597
598    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
599        args = [
600            self.fns_repr(),
601            *(pp.pretty_print(x) for x in self.args),
602            *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
603        ]
604        if self.users is MULTIPLE:
605            args.append("_users=MULTIPLE")
606        elif self.users != 1:
607            args.append(f"_users={self.users}")
608
609        joiner_str = ", "
610        return f"{self.__class__.__name__}({joiner_str.join(args)})"
611
612    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
613        if not self._match_fns(node) or len(node.args) != len(self.args):
614            return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
615
616        if not self._match_users(node, ctx):
617            return FailedMatch("multiple_users {}", self)
618
619        _args = node.args
620        _kwargs = node.kwargs
621        if len(_kwargs) < len(self.kwargs):
622            from torch.fx.operator_schemas import normalize_function
623
624            normalized_args_and_kwargs = normalize_function(
625                node.target, node.args, node.kwargs  # type: ignore[arg-type]
626            )
627
628            if normalized_args_and_kwargs is None:
629                return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
630            else:
631                _args, _kwargs = normalized_args_and_kwargs
632                if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
633                    _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
634                else:
635                    return FailedMatch(
636                        "function_mismatch: node={}, pattern={}", node, self
637                    )
638        else:
639            _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
640
641        node_items, node_spec = self.flatten(_args, _kwargs)
642        self_items, self_spec = self.flat_args_kwargs
643        if node_spec != self_spec:
644            return FailedMatch("args_structure {} {}", node_spec, self_spec)
645        assert len(node_items) == len(self_items)
646
647        m = Match(ctx, self)
648        for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
649            if isinstance(pattern, PatternExpr):
650                child_match = ctx.match(pattern, child_node)
651                if not is_match(child_match):
652                    return child_match
653                m.extend(child_match)
654            elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
655                return FailedMatch(
656                    "constant_args: {} {!r}!={pattern!r}", node, child_node
657                )
658        m.nodes.append(node)
659        m.targets[self] = node.target
660        return m
661
662    def find_anchor_nodes(
663        self, ctx: MatchContext, searched: Set[torch.fx.Node]
664    ) -> Generator[Optional[torch.fx.Node], None, None]:
665        """
666        This is used when we are matching a pattern with multiple outputs.
667        There is a partial match (stored in ctx) and we want to walk
668        this pattern to find a connection to an already-matched node.
669
670        Yields candidate nodes that `self._match` might like.
671        """
672        if self in ctx.pattern_to_node:
673            yield ctx.pattern_to_node[self]
674            return
675
676        for pattern in self.flat_args_kwargs[0]:
677            if isinstance(pattern, PatternExpr):
678                for other_node in pattern.find_anchor_nodes(ctx, searched):
679                    if not isinstance(other_node, torch.fx.Node):
680                        continue
681                    for node in other_node.users:
682                        if node not in searched:
683                            if self._match_fns(node):
684                                yield node
685                                searched.add(node)
686
687    def pattern_eq(self, other: Any) -> bool:
688        other = typing.cast(Self, other)  # super makes sure this is true
689        return (
690            super().pattern_eq(other)
691            and self.flat_args_kwargs[1] == other.flat_args_kwargs[1]
692            and all(
693                a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
694                for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0])
695            )
696        )
697
698
699class CallFunction(_TargetArgsExpr):
700    """
701    Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
702    """
703
704    op = "call_function"
705
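# For example, a pattern matching relu(add(x, other)) can be spelled directly
# with the expression classes above (an illustrative sketch; this particular
# pattern is not registered anywhere in this file):
#
#   _add_relu_pattern = CallFunction(
#       aten.relu,
#       CallFunction(aten.add, KeywordArg("x"), KeywordArg("other")),
#   )
#
# By default every non-output node must have exactly one user; pass
# `_users=MULTIPLE` (or an int) to relax or change that constraint.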
706
707class CallMethod(_TargetArgsExpr):
708    """
709    Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)`
710    """
711
712    op = "call_method"
713
714
715class CallModule(_TargetArgsExpr):
716    """
717    Matches a call_module node in the FX graphs: `module(*args, **kwargs)`
718    """
719
720    op = "call_module"
721
722
723class _TargetExprVarArgs(_TargetExpr):
724    """
725    Matches a node with any arguments; the node's args and kwargs are forwarded to the Match
726    """
727
728    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
729        if not self._match_fns(node):
730            return FailedMatch("function_mismatch")
731
732        if not self._match_users(node, ctx):
733            return FailedMatch("multiple_users")
734
735        m = Match(ctx, self)
736        m.nodes.append(node)
737        m.targets[self] = node.target
738        m.args.extend(node.args)
739        m.kwargs.update(node.kwargs)
740        return m
741
742
743class CallFunctionVarArgs(_TargetExprVarArgs):
744    op = "call_function"
745
746
747class CallMethodVarArgs(_TargetExprVarArgs):
748    op = "call_method"
749
750
751class CallModuleVarArgs(_TargetExprVarArgs):
752    op = "call_module"
753
754
755class ListOf(PatternExpr):
756    """
757    Matches a list whose elements each match the given pattern (at least one element if partial=True)
758    """
759
760    def __init__(self, pattern: PatternExpr, partial: bool = False) -> None:
761        super().__init__()
762        assert isinstance(pattern, PatternExpr)
763        self.pattern = pattern
764        self.partial = partial
765
766    def __repr__(self) -> str:
767        return f"{self.__class__.__name__}({self.pattern})"
768
769    def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult:  # type: ignore[override]
770        if not isinstance(node, (list, tuple)) or len(node) == 0:
771            return FailedMatch("non_list")
772        m = Match(ctx, self)
773        # Propagating patterns with multiple users will ensure we don't revisit
774        # the same nodes
775        pattern_to_node = ctx.filter_multi_user_patterns()
776        matched = False
777        for i, child_node in enumerate(node):
778            child_ctx = MatchContext(
779                ctx.outputs, pattern_to_node, graph=child_node.graph
780            )
781            child_match = child_ctx.match(self.pattern, child_node)
782            pattern_to_node = child_ctx.filter_multi_user_patterns()
783            if not is_match(child_match):
784                if not self.partial:
785                    return FailedMatch("list[{}]: {}", i, child_match)
786                continue
787            matched = True
788            m.extend(child_match.bundle())
789        if not matched:
790            return FailedMatch("list: no_match")
791        return m.bundle()
792
793    def pattern_eq(self, other: Any) -> bool:
794        other = typing.cast(Self, other)  # super makes sure this is true
795        return (
796            super().pattern_eq(other)
797            and self.pattern.pattern_eq(other.pattern)
798            and self.partial == other.partial
799        )
800
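# ListOf is useful when an argument in the graph is itself a list of nodes,
# such as the first argument to aten.cat.  A hedged sketch (not a registered
# pattern; the trailing Ignored() stands for the dim argument):
#
#   _cat_of_unsqueeze = CallFunction(
#       aten.cat,
#       ListOf(CallFunction(aten.unsqueeze, Arg(), Ignored())),
#       Ignored(),
#   )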
801
802class MultiOutputPattern(PatternExpr):
803    outputs: List[Optional[PatternExpr]]
804
805    def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None:
806        super().__init__()
807        assert isinstance(outputs[0], _TargetExpr)
808        assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs
809        self.outputs = list(outputs)
810        self.op = outputs[0].op
811
812    @property
813    def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]:
814        # This cast is checked above in __init__()
815        output = typing.cast(_TargetExpr, self.outputs[0])
816        return output.fns
817
818    def __repr__(self) -> str:
819        return f"{self.__class__.__name__}({self.outputs})"
820
821    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
822        args = [pp.pretty_print(x) for x in self.outputs]
823        joiner_str = f",\n{'  '}"
824        str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
825        str_out = f"{str_out}\n])"
826        return str_out
827
828    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
829        output = typing.cast(_TargetExpr, self.outputs[0])
830        m = ctx.match(output, node)
831        if not is_match(m):
832            return m
833
834        for pattern in self.outputs[1:]:
835            if pattern is None:
836                continue
837            child_match = self._match_from_anchors(pattern, ctx)
838            if not is_match(child_match):
839                return child_match
840            m.extend(child_match)
841
842        return m
843
844    def _match_from_anchors(
845        self, pattern: PatternExpr, ctx: MatchContext
846    ) -> MatchResult:
847        prior = dict(ctx.pattern_to_node)
848        m: MatchResult = FailedMatch("no anchor found")
849        for node in pattern.find_anchor_nodes(ctx, set()):
850            m = ctx.match(pattern, node)
851            if is_match(m):
852                return m
853            # revert any partial matches
854            ctx.pattern_to_node = dict(prior)
855        return m
856
857    def match(self, node: torch.fx.Node) -> MatchResult:
858        try:
859            return MatchContext(self.outputs, graph=node.graph).match(self, node)
860        except FailedMatch as e:
861            return e
862
863    def pattern_eq(self, other: Any) -> bool:
864        other = typing.cast(Self, other)  # super makes sure this is true
865        return (
866            super().pattern_eq(other)
867            and len(self.outputs) == len(other.outputs)
868            and all(
869                a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
870                for a, b in zip(self.outputs, other.outputs)
871            )
872        )
873
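# MultiOutputPattern is used when the thing being replaced has several outputs
# that must all be matched, e.g. an op whose results are consumed through
# getitem projections.  A hedged sketch (not a real registered pattern):
#
#   _vm = CallFunction(aten.var_mean, Arg(), _users=2)
#   _var_mean_pattern = MultiOutputPattern([
#       CallFunction(operator.getitem, _vm, 0),
#       CallFunction(operator.getitem, _vm, 1),
#   ])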
874
875class RepeatedExpr(PatternExpr):
876    """
877    Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind`
878    """
879
880    def __init__(self, inner_pattern: _TargetExpr) -> None:
881        super().__init__()
882        self.inner_pattern = inner_pattern
883        self.op = inner_pattern.op
884
885    @property
886    def fns(self) -> Sequence[FnsType]:
887        return self.inner_pattern.fns
888
889    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
890        m = ctx.match(self.inner_pattern, node)
891        if not is_match(m):
892            return m
893        ctx.pattern_to_node.pop(
894            self.inner_pattern,
895        )
896        # Check all anchor nodes match the pattern
897        for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()):
898            anchor_m = MatchContext([self], graph=node.graph).match(
899                self.inner_pattern, anchor_node
900            )
901            if not is_match(anchor_m):
902                return anchor_m
903            m.extend(anchor_m)
904        return m
905
906    def pattern_eq(self, other: Any) -> bool:
907        other = typing.cast(Self, other)  # super makes sure this is true
908        return super().pattern_eq(other) and self.inner_pattern.pattern_eq(
909            other.inner_pattern
910        )
911
912
913class PatternPrettyPrinter:
914    """
915    Serializes Patterns to executable python.
916    XXX: currently only used and tested for fuse attention patterns. May not cover
917    all patterns.
918    """
919
920    def __init__(self) -> None:
921        self.namespace = torch.fx.graph._Namespace()
922        self.memoized_objs_names: Dict[PatternExpr, str] = {}
923        self.memoized_objs_pp: Dict[PatternExpr, str] = {}
924
925    @staticmethod
926    @functools.lru_cache(None)
927    def run(obj: PatternExpr, output_name: str = "output") -> str:
928        """
929        Serializes obj to python code with obj written out to `output_name`
930        """
931
932        pp = PatternPrettyPrinter()
933        assert hasattr(obj, "pretty_print")
934        out_str = obj.pretty_print(pp=pp)
935
936        output = []
937        for key in pp.memoized_objs_names:
938            output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")
939
940        output.append(f"{output_name} = {out_str}")
941
942        return "\n".join(output)
943
944    def pretty_print(self, obj: Any) -> str:
945        if isinstance(obj, _TargetArgsExpr):
946            if memoized_name := self.memoized_objs_names.get(obj):
947                return memoized_name
948            else:
949                return self.memoize(obj)
950        if hasattr(obj, "pretty_print"):
951            return obj.pretty_print(self)
952
953        return repr(obj)
954
955    def memoize(self, obj: _TargetArgsExpr) -> str:
956        obj_str = obj.pretty_print(self)
957        obj_name = obj.fns_repr()
958        for prefix in ("aten.", "torch.", "prims."):
959            obj_name = obj_name.replace(prefix, "")
960
961        tmp_name = self.namespace.create_name(obj_name, None)
962        self.memoized_objs_names[obj] = tmp_name
963        self.memoized_objs_pp[obj] = obj_str
964        return tmp_name
965
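# PatternPrettyPrinter.run() serializes a pattern back to executable Python; it
# is what _serialize_pattern() below (driven by torchgen/fuse/gen_patterns.py)
# uses to produce the files under fx_passes/serialized_patterns.  Roughly:
#
#   src = PatternPrettyPrinter.run(some_pattern, output_name="my_pattern")
#   # `src` is Python source that, when executed, rebuilds `some_pattern`
#   # and assigns it to a variable named `my_pattern`.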
966
967class _PassDictsType(Protocol):
968    def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
969        ...
970
971
972@dataclasses.dataclass
973class PatternEntry:
974    pattern: PatternExpr
975    extra_check: Callable[[Match], bool]
976
977    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
978        raise NotImplementedError
979
980    def register(
981        self,
982        pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
983        target: Union[torch.fx.node.Target, None] = None,
984        prepend: bool = False,
985    ) -> None:
986        if target is None:
987            assert hasattr(self.pattern, "fns")
988            for fn in self.pattern.fns:
989                self.register(pass_dicts, fn, prepend=prepend)
990        elif isinstance(pass_dicts, (dict, PatternMatcherPass)):
991            assert hasattr(self.pattern, "op")
992            if prepend:
993                pass_dicts[(self.pattern.op, target)].insert(0, self)
994            else:
995                pass_dicts[(self.pattern.op, target)].append(self)
996        else:
997            pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts)
998            for x in pass_dicts:
999                self.register(x, target, prepend=prepend)
1000
1001
1002@dataclasses.dataclass
1003class LoweringPatternEntry(PatternEntry):
1004    handler: Callable[..., Any]
1005
1006    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
1007        handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
1008        with graph.inserting_before(node):
1009            replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
1010            replacement.meta.update(node.meta)
1011            node.replace_all_uses_with(replacement)
1012        assert match.nodes[-1] is node
1013        match.erase_nodes()
1014
1015
1016@dataclasses.dataclass
1017class GraphPatternEntry(PatternEntry):
1018    """
1019    A pattern that runs a function on the FX graph
1020    """
1021
1022    handler: Callable[..., Any]
1023
1024    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
1025        with graph.inserting_before(node):
1026            self.handler(match, *match.args, **match.kwargs)
1027
1028
1029@dataclasses.dataclass
1030class ReplacementPatternEntry(PatternEntry):
1031    normalize_args: Callable[..., List[Any]]
1032
1033    @staticmethod
1034    def replace_with_graph(
1035        match: Match,
1036        graph: torch.fx.Graph,
1037        replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule],
1038        args: Sequence[torch.fx.Node],
1039    ) -> None:
1040        class Replacer(torch.fx.Interpreter):
1041            call_method = None  # type: ignore[assignment]
1042            call_module = None  # type: ignore[assignment]
1043            get_attr = None  # type: ignore[assignment]
1044
1045            def run_node(self, node: torch.fx.Node) -> Any:
1046                if node.op in ("placeholder", "output"):
1047                    return super().run_node(node)
1048                if node.op == "call_function":
1049                    target = node.target
1050                    args, kwargs = self.fetch_args_kwargs_from_env(node)
1051                    result = graph.call_function(target, args, kwargs)  # type: ignore[arg-type]
1052                    if "val" in node.meta and "val" not in result.meta:
1053                        result.meta["val"] = node.meta["val"]
1054                        if isinstance(node.meta["val"], torch.Tensor):
1055                            assert "tensor_meta" in node.meta
1056                            result.meta["tensor_meta"] = node.meta["tensor_meta"]
1057                    return result
1058                raise NotImplementedError(f"unhandled {node}")
1059
1060        output_nodes = match.output_nodes()
1061
1062        if len(output_nodes) == 1:
1063            last_node = output_nodes[0]
1064        else:
1065            assert output_nodes[0]
1066            nodes = list(output_nodes[0].graph.nodes)
1067            indices = [
1068                (nodes.index(n), n)
1069                for n in output_nodes
1070                if isinstance(n, torch.fx.Node)
1071            ]
1072            last_node = min(indices, key=operator.itemgetter(0))[1]
1073
1074        def percolate_tags(
1075            node: torch.fx.Node,
1076            tag_name: str,
1077            tag_value: str,
1078            input_stops: Set[torch.fx.Node],
1079        ) -> None:
1080            queue = [node]
1081            visited = set()
1082
1083            while queue:
1084                arg = queue.pop()
1085                if (
1086                    arg not in visited
1087                    and arg not in input_stops
1088                    and hasattr(arg, "meta")
1089                ):
1090                    visited.add(arg)
1091                    arg.meta[tag_name] = tag_value
1092                    queue.extend(arg.all_input_nodes)
1093
1094        with graph.inserting_before(last_node):
1095            replacement = Replacer(replacement_graph).run(*args)  # type: ignore[arg-type]
1096            if isinstance(replacement, torch.fx.Node):
1097                replacement = [replacement]
1098
1099            def maybe_getitem(node: torch.fx.Node) -> Any:
1100                if node.op != "call_function":
1101                    return None
1102                if node.target != operator.getitem:
1103                    return None
1104                assert len(node.args) == 2
1105                return node.args[1]
1106
1107            def replace(
1108                old: Union[torch.fx.Node, None],
1109                new: Union[torch.fx.Node, Sequence[torch.fx.Node], None],
1110            ) -> None:
1111                if old is None:
1112                    assert new is None
1113                    return
1114                assert isinstance(old, torch.fx.Node)
1115                if new is None:
1116                    old.replace_all_uses_with(None)  # type: ignore[arg-type]
1117                    graph.erase_node(old)
1118                    return
1119                if isinstance(new, torch.fx.Node):
1120                    if "val" not in new.meta:
1121                        new.meta.update(old.meta)
1122
1123                    # Preserve the recompute tags in the replacement graph. We
1124                    # look at the recompute tags of the original output node and
1125                    # propagate the tag from the output all the way to the input
1126                    # args (named `args` in replace_with_graph).
1127                    # Note that this is best effort. Since patterns map many
1128                    # nodes to many nodes, there is no easy way to map the
1129                    # recompute tags exactly. In some scenarios we may
1130                    # incorrectly tag some nodes as recomputable.
1131                    for tag_name in ["recompute", "ac_graph_id"]:
1132                        if tag_name in old.meta:
1133                            percolate_tags(new, tag_name, old.meta[tag_name], set(args))
1134
1135                    old.replace_all_uses_with(new)
1136                    graph.erase_node(old)
1137                    return
1138
1139                # `new` is not a node: it's a list of nodes.
1140                #
1141                # This happens when we want to replace a node that has a single
1142                # packed return with multiple unpacked returns. We need to do
1143                # some graph surgery here.
1144                #
1145                # Example:
1146                #   def original_graph(x):
1147                #      a = op(x)
1148                #      b = a[0]
1149                #      c = a[1]
1150                #      ...
1151                #
1152                # Assume that we want to replace op(x) with the graph
1153                #   def new_op(x):
1154                #      w = x + 1
1155                #      z = x + 2
1156                #      return (w, z)
1157                #
1158                # We need to replace `op` with the contents of `new_op`,
1159                # and then rewrite a[0] to be w and a[1] to be z, as so:
1160                #   def new_graph(x):
1161                #     w = x + 1
1162                #     z = x + 2
1163                #     b = w
1164                #     c = z
1165                #     ...
1166                old_uses = list(old.users.keys())
1167                for user in old_uses:
1168                    idx = maybe_getitem(user)
1169                    if idx is None:
1170                        raise AssertionError("can't handle")
1171                    replace(user, new[idx])  # type: ignore[index]
1172                graph.erase_node(old)
1173
1174            if len(output_nodes) == len(replacement):
1175                for old, new in zip(output_nodes, replacement):
1176                    replace(old, new)
1177            else:
1178                assert len(output_nodes) == 1
1179                replace(output_nodes[0], replacement)
1180
1181        match.erase_nodes()
1182
1183    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
1184        assert match.replacement_graph is not None
1185        self.replace_with_graph(
1186            match,
1187            graph,
1188            match.replacement_graph,
1189            self.normalize_args(*match.args, **match.kwargs),
1190        )
1191
1192
1193def _return_true(match: Match) -> bool:
1194    return True
1195
1196
1197def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None:
1198    log.info(
1199        "Replacement pattern %s failed to apply due to shape mismatch: %s",
1200        search_fn.__name__,
1201        e,
1202    )
1203
1204
1205def register_replacement(
1206    search_fn: SearchFn,
1207    replace_fn: ReplaceFn,
1208    example_inputs: Iterable[Any],
1209    trace_fn: TraceFn,
1210    pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
1211    extra_check: Callable[[Match], bool] = _return_true,
1212    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
1213    exclusive_arg_names: Sequence[str] = (),
1214    search_fn_pattern: Union[PatternExpr, None] = None,
1215) -> bool:
1216    """
1217    Create a replacement rule based on example functions that get traced
1218    to create patterns.  This supports both training and inference when
1219    run on a joint forward+backward graph.
1220
1221    Args:
1222        search_fn: traced to give original pattern
1223        replace_fn: traced to give replacement graph
1224        example_inputs: example inputs for initial trace
1225        trace_fn: fwd_only or joint_fwd_bwd
1226        pass_dicts: dict of passes (or sequence of such dicts) to register to
1227        extra_check: additional check to run on the match (using real shapes)
1228    """
1229    argnames_static = [*inspect.signature(search_fn).parameters.keys()]
1230
1231    def check_fn(match: Match) -> bool:
1232        """
1233        Often shapes get burned into the pattern, so our initial match ran with
1234        `ignore_types=(int, ...)`.
1235
1236        Recheck the match with the correct shapes.
1237        """
1238        argnames = list(argnames_static)
1239        for name in argnames:
1240            if name not in match.kwargs:
1241                raise RuntimeError(
1242                    f"Not all inputs to pattern found in match.kwargs. Perhaps one "
1243                    f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
1244                )
1245
1246        args = list(
1247            torch.fx.map_arg(  # type: ignore[arg-type]
1248                [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
1249            )
1250        )
1251        sym_args: List[torch.SymInt] = []
1252        with torch._dynamo.utils.detect_fake_mode(args):
1253            for i, grad in enumerate(requires_grad):
1254                if isinstance(args[i], torch.Tensor):
1255                    if grad and is_integer_dtype(args[i].dtype):
1256                        return False
1257
1258                    args[i] = torch.empty_strided(
1259                        args[i].size(),
1260                        args[i].stride(),
1261                        dtype=args[i].dtype,
1262                        device=args[i].device,
1263                        requires_grad=grad,
1264                    )
1265                    for v in itertools.chain(args[i].shape, args[i].stride()):
1266                        if isinstance(v, torch.SymInt) and all(
1267                            guard_size_oblivious(v != a) for a in sym_args
1268                        ):
1269                            sym_args.append(v)
1270
1271            # If we were given a pre-traced pattern then use that instead of
1272            # retracing. Note that this means the pattern has to be independent
1273            # of its args.
1274            specific_pattern = search_fn_pattern
1275
1276            if not specific_pattern:
1277                if sym_args:
1278                    # AOT Autograd and make fx will dedupe symbolic shape size
1279                    # accesses of sym ints that appear as inputs
1280                    # We don't want the sym_size uses to interfere with pattern matching
1281                    # so we provide them as inputs.
1282                    # Later, when we actually do the replacement, the symbolic shape
1283                    # sizes will get re-traced and added to the graph.
1284
1285                    def search_fn_new(*args_new: Any) -> Any:
1286                        return search_fn(*args_new[len(args_new) - len(args) :])
1287
1288                    try:
1289                        specific_graph = trace_fn(search_fn_new, sym_args + args)
1290                    except RuntimeError as e:
1291                        log_trace_failure(search_fn, e)
1292                        return False
1293
1294                    # correct argnames in the graph
1295                    sym_arg_names = []
1296                    for i, placeholder in zip(
1297                        range(len(sym_args) + len(args)),
1298                        specific_graph.graph.nodes,
1299                    ):
1300                        if i < len(sym_args):
1301                            sym_arg_names.append(placeholder.target)
1302                            continue
1303
1304                        with specific_graph.graph.inserting_after(placeholder):
1305                            new_node = specific_graph.graph.placeholder(
1306                                argnames[i - len(sym_args)]
1307                            )
1308                            new_node.target = new_node.name
1309                            placeholder.replace_all_uses_with(new_node)
1310                            specific_graph.graph.erase_node(placeholder)
1311
1312                    argnames = sym_arg_names + argnames
1313                else:
1314                    try:
1315                        specific_graph = trace_fn(search_fn, args)
1316                    except RuntimeError as e:
1317                        log_trace_failure(search_fn, e)
1318                        return False
1319
1320                specific_pattern = fx_to_pattern(
1321                    specific_graph,
1322                    argnames=argnames,
1323                    exclusive_arg_names=exclusive_arg_names,
1324                    scalar_workaround=scalar_workaround,
1325                )
1326
1327            node = match.output_nodes()[0]
1328            assert node is not None
1329            specific_pattern_match = specific_pattern.match(node)
1330
1331            if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
1332                # trace the pattern using the shapes from the user program
1333                match.replacement_graph = trace_fn(replace_fn, args)  # type: ignore[assignment]
1334                return True
1335            return False
1336
1337    def normalize_args(**kwargs: Any) -> List[Any]:
1338        args = []
1339        for name in argnames_static:
1340            args.append(kwargs.pop(name))
1341        for i in range(1, len(kwargs) + 1):
1342            if f"tangents_{i}" not in kwargs:
1343                break
1344            args.append(kwargs.pop(f"tangents_{i}"))
1345        assert not kwargs, f"leftover kwargs: {kwargs!r}"
1346        return args
1347
1348    if trace_fn is joint_fwd_bwd:
1349        # If inference mode is enabled during compilation, assume that we don't
1350        # want to match on any training graph patterns
1351        if torch.is_inference_mode_enabled():
1352            return False
1353
1354    # TODO: Revisit the functionalize_rng_ops for lowmem dropout
1355    with functorch_config.patch(functionalize_rng_ops=False):
1356        requires_grad: List[bool] = [
1357            isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
1358        ]
1359        if search_fn_pattern is None:
1360            pattern = gen_pattern(
1361                search_fn,
1362                example_inputs,
1363                trace_fn,
1364                scalar_workaround,
1365                exclusive_arg_names,
1366            )
1367        else:
1368            pattern = search_fn_pattern
1369
1370        pattern_repr = PatternPrettyPrinter.run(pattern)
1371        assert pattern_repr not in _seen_patterns
1372        _seen_patterns.add(pattern_repr)
1373        pattern = ReplacementPatternEntry(
1374            pattern=pattern,
1375            extra_check=check_fn,
1376            normalize_args=normalize_args,
1377        )
1378        pattern.register(pass_dicts)
1379        return pattern.pattern
1380
1381
1382_serialized_patterns: Set[str] = set()
1383
1384
1385def _serialize_pattern(
1386    unique_name: str,
1387    search_fn: SearchFn,
1388    example_inputs: Iterable[Any],
1389    trace_fn: TraceFn,
1390    scalar_workaround: Union[Dict[str, Union[float, int]], None],
1391) -> PatternExpr:
1392    def get_file_template() -> str:
1393        auto_generated_msg = textwrap.dedent(
1394            """\
1395            # This is an auto-generated file. Please do not modify it by hand.
1396            # To re-generate, run:
1397            # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
1398            """
1399        )
1400
1401        file_template = textwrap.dedent(
1402            """\
1403            # mypy: ignore-errors
1404
1405            # noqa: F401, E501
1406            {msg}
1407            import torch
1408            import torch._inductor
1409
1410            aten = torch.ops.aten
1411            prims = torch.ops.prims
1412
1413            """
1414        ).format(msg=auto_generated_msg)
1415
1416        pattern_matcher_imports = []
1417        for name in dir(torch._inductor.pattern_matcher):
1418            attr = getattr(torch._inductor.pattern_matcher, name)
1419            if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
1420                pattern_matcher_imports.append(name)
1421
1422        formatted_imports = ",\n   ".join(pattern_matcher_imports)
1423        formatted_imports = f"from torch._inductor.pattern_matcher import (\n   {formatted_imports},\n)\n"
1424        return f"{file_template}{formatted_imports}"
1425
1426    if not SERIALIZED_PATTERN_PATH.is_dir():
1427        raise RuntimeError(
1428            f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}"
1429        )
1430
1431    pattern_name = search_fn.__name__
1432
1433    from torch._functorch import config as functorch_config
1434
1435    with functorch_config.patch(functionalize_rng_ops=False):
1436        pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround)
1437
1438    serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name)
1439    if pattern_name not in _serialized_patterns:
1440        write_mode = "w"
1441        _serialized_patterns.add(pattern_name)
1442    else:
1443        write_mode = "a"
1444
1445    file_template = get_file_template()
1446
1447    with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f:
1448        if write_mode == "w":
1449            f.write(file_template)
1450        else:
1451            f.write("\n\n")
1452        f.write(serialized_pattern)
1453        f.write("\n")
1454
1455    return pattern
1456
1457
1458SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns"
1459
1460# This is the list of precompiled patterns that we've registered.  Used by
1461# test_serialized_patterns_up_to_date() to ensure the serialized patterns are up
1462# to date.
1463_known_precompiled_patterns: List[
1464    Tuple[
1465        Any,
1466        Iterable[Any],
1467        Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
1468        Any,
1469        PatternExpr,
1470    ]
1471] = []
1472
1473
1474def gen_register_replacement(
1475    unique_name: str,
1476    search_fn: SearchFn,
1477    replace_fn: ReplaceFn,
1478    example_inputs: Iterable[Any],
1479    trace_fn: TraceFn,
1480    pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
1481    extra_check: Callable[[Match], bool] = _return_true,
1482    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
1483    exclusive_arg_names: Sequence[str] = (),
1484    skip_duplicates: bool = False,
1485) -> None:
1486    # Make sure the example_inputs is materialized.
1487    example_inputs = tuple(example_inputs)
1488
1489    if "PYTORCH_GEN_PATTERNS" in os.environ:
1490        pat = _serialize_pattern(
1491            unique_name, search_fn, example_inputs, trace_fn, scalar_workaround
1492        )
1493    else:
1494        pattern_name = search_fn.__name__
1495        m = importlib.import_module(
1496            f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}"
1497        )
1498        if not m or not hasattr(m, unique_name):
1499            log.warning(
1500                "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.",
1501                unique_name,
1502            )
1503        pat = getattr(m, unique_name)
1504
1505    for arg in pytree.tree_iter(example_inputs):
1506        if isinstance(arg, FakeTensor) and arg.constant is not None:
1507            # This can be a problem: small fake tensors (e.g. `tensor(2)`) will
1508            # hold onto their original constant value, and stashing that constant
1509            # here would cause a memory leak if it lives on the GPU.
1510            # Since the constant is just an optimization we can clear it out.
1511            arg.constant = None
1512
1513    if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates:
1514        return
1515    _known_precompiled_patterns.append(
1516        (search_fn, example_inputs, trace_fn, scalar_workaround, pat)
1517    )
1518    register_replacement(
1519        search_fn,
1520        replace_fn,
1521        example_inputs,
1522        trace_fn,
1523        pass_dicts,
1524        extra_check,
1525        scalar_workaround,
1526        exclusive_arg_names,
1527        search_fn_pattern=pat,
1528    )
1529
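
# A hedged usage sketch of gen_register_replacement(); the search/replace
# functions, example inputs, and the "example_double_relu" name are
# hypothetical. Unless PYTORCH_GEN_PATTERNS is set, the call expects a
# serialized pattern module named after the search function to already exist
# under fx_passes/serialized_patterns/ (see the importlib lookup above).
def _example_gen_register_replacement(pass_dict: _PassDictsType) -> None:
    def example_search(x: torch.Tensor) -> torch.Tensor:
        # relu is idempotent, so a double relu...
        return x.relu().relu()

    def example_replace(x: torch.Tensor) -> torch.Tensor:
        # ...can be rewritten as a single relu.
        return x.relu()

    gen_register_replacement(
        "example_double_relu",
        example_search,
        example_replace,
        (torch.empty(4, 4),),
        fwd_only,
        pass_dict,
    )
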
1530
1531@functorch_config.patch(functionalize_rng_ops=False)
1532def gen_pattern(
1533    search_fn: SearchFn,
1534    example_inputs: Sequence[Any],
1535    trace_fn: TraceFn,
1536    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
1537    exclusive_arg_names: Sequence[str] = (),
1538) -> PatternExpr:
1539    argnames = [*inspect.signature(search_fn).parameters.keys()]
1540
1541    if scalar_workaround is None:
1542        scalar_workaround = {}
1543    flat_inputs = []
1544    input_idx = 0  # Positional arguments index
1545
1546    for argname in argnames:
1547        if argname in scalar_workaround:
1548            flat_inputs.append(scalar_workaround[argname])
1549        else:
1550            flat_inputs.append(example_inputs[input_idx])
1551            input_idx += 1
1552
1553    search_gm = trace_fn(search_fn, flat_inputs)
1554    return fx_to_pattern(
1555        search_gm,
1556        ignore_types=(int, float, list, torch.device, torch.dtype),
1557        argnames=argnames,
1558        scalar_workaround=scalar_workaround,
1559        exclusive_arg_names=exclusive_arg_names,
1560    )
1561
1562
1563def register_lowering_pattern(
1564    pattern: PatternExpr,
1565    extra_check: Callable[[Match], bool] = _return_true,
1566    *,
1567    pass_dict: _PassDictsType,
1568    prepend: bool = False,
1569) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
1570    """
1571    Register an aten to inductor IR replacement pattern.  The decorated
1572    function is saved and then called at lowering time, allowing direct
1573    pattern to inductor IR conversion.
1574    """
1575
1576    def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
1577        assert callable(handler)
1578        LoweringPatternEntry(
1579            pattern=pattern, extra_check=extra_check, handler=handler
1580        ).register(pass_dict, prepend=prepend)
1581        handler._inductor_lowering_function = True  # type: ignore[attr-defined]
1582        return handler
1583
1584    return decorator
1585
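
# A hedged usage sketch of register_lowering_pattern(); the pattern, the pass,
# and the handler body are illustrative. A real handler builds inductor IR for
# the matched subgraph; here, purely for illustration, a matched aten.clone is
# lowered to its (already lowered) input unchanged.
def _example_register_lowering_pattern(pass_dict: _PassDictsType) -> None:
    @register_lowering_pattern(
        CallFunction(torch.ops.aten.clone.default, KeywordArg("x")),
        pass_dict=pass_dict,
    )
    def _clone_handler(match: Match, x: Any) -> Any:
        # At lowering time `x` is the inductor IR value bound to
        # KeywordArg("x"); returning it drops the clone.
        return x
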
1586
1587def register_graph_pattern(
1588    pattern: PatternExpr,
1589    extra_check: Callable[[Match], bool] = _return_true,
1590    *,
1591    pass_dict: _PassDictsType,
1592    prepend: bool = False,
1593) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
1594    """
1595    Register a pattern that runs a function on the FX graph, allowing
1596    custom transformation code.
1597    """
1598
1599    def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
1600        assert callable(handler)
1601        GraphPatternEntry(
1602            pattern=pattern, extra_check=extra_check, handler=handler
1603        ).register(pass_dict, prepend=prepend)
1604        return handler
1605
1606    return decorator
1607
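
# A hedged usage sketch of register_graph_pattern(); the pattern and handler
# are illustrative. The handler receives the Match plus the captured
# KeywordArg bindings and may rewrite the FX graph in place.
def _example_register_graph_pattern(pass_dict: _PassDictsType) -> None:
    @register_graph_pattern(
        CallFunction(torch.ops.aten.add.Tensor, KeywordArg("x"), KeywordArg("y")),
        pass_dict=pass_dict,
    )
    def _add_handler(match: Match, x: torch.fx.Node, y: torch.fx.Node) -> None:
        # e.g. inspect match.nodes / match.output_node() and mutate the graph.
        ...
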
1608
1609def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
1610    # first node in the graph
1611    return node is next(iter(graph.nodes))
1612
1613
1614# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc
1615# doesn't match: __rshift__, etc
1616_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")
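# Illustrative breakdown of the regex above: `_$` catches trailing-underscore
# in-place ops ("copy_", "relu_"), `_[.]` catches the same names when followed
# by an overload suffix (e.g. "copy_.default"), and the set/enter/exit/seed
# alternatives catch state-changing ops; the (?<!_)/(?!_) guards keep dunder
# names such as "__rshift__" from matching.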
1617
1618
1619def is_mutation_op(node: torch.fx.Node) -> bool:
1620    if node.op == "call_function":
1621        if _mutation_op_re.search(node.target.__name__):  # type: ignore[union-attr]
1622            return True
1623    elif node.op == "call_method":
1624        if _mutation_op_re.search(node.target):  # type: ignore[union-attr, arg-type]
1625            return True
1626    return node.kwargs.get("out") is not None
1627
1628
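# A "mutation region" is a run of nodes between mutation ops: the region id
# starts at 0 and is incremented each time a mutation op is encountered while
# walking the graph in order. PatternMatcherPass.apply() below discards any
# match whose nodes span more than one region, since rewriting across a
# mutation barrier could reorder reads and writes.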
1629def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool:
1630    assert "mutation_region_id" in a.meta
1631    assert "mutation_region_id" in b.meta
1632    return a.meta["mutation_region_id"] == b.meta["mutation_region_id"]
1633
1634
1635def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
1636    n = node
1637    while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
1638        n = n.prev
1639    mutation_region_id = n.meta.get("mutation_region_id", 0)
1640    while n is not node:
1641        n = n.next
1642        if is_mutation_op(n):
1643            mutation_region_id += 1
1644        n.meta["mutation_region_id"] = mutation_region_id
1645    return mutation_region_id
1646
1647
1648def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
1649    return "mutation_region_id" not in next(iter(graph.nodes)).meta
1650
1651
1652def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
1653    mutation_region_id = 0
1654    for nd in graph.nodes:
1655        if is_mutation_op(nd):
1656            mutation_region_id += 1
1657        nd.meta["mutation_region_id"] = mutation_region_id
1658
1659
1660class PatternMatcherPass:
1661    def __init__(
1662        self,
1663        pass_name: Optional[str] = None,
1664    ) -> None:
1665        super().__init__()
1666        self.patterns: DefaultDict[
1667            Tuple[str, torch.fx.node.Target], List[PatternEntry]
1668        ] = defaultdict(list)
1669        self.pass_name = pass_name
1670
1671    def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
1672        return self.patterns[item]
1673
1674    def apply(self, gm: torch.fx.GraphModule) -> int:
1675        if not self.patterns:
1676            return 0
1677        if isinstance(gm, torch.fx.GraphModule):
1678            graph = gm.graph
1679        elif isinstance(gm, torch.fx.Graph):
1680            graph = gm
1681            gm = graph.owning_module
1682        else:
1683            raise RuntimeError(
1684                f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
1685            )
1686        if should_compute_mutation_region_ids(graph):  # type: ignore[arg-type]
1687            compute_mutation_region_ids(graph)  # type: ignore[arg-type]
1688        get_mutation_region_id_partial = functools.partial(
1689            get_mutation_region_id, graph
1690        )
1691        count = 0
1692        nodes = []
1693        has_call_module = False
1694        for op, target in self.patterns:
1695            if op == "call_module":
1696                has_call_module = True
1697            else:
1698                nodes.append(graph.find_nodes(op=op, target=target, sort=False))
1699        if has_call_module:
1700            nodes.append(graph.find_nodes(op="call_module", sort=False))
1701        pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
1702        with GraphTransformObserver(
1703            gm, pass_name, trace_config.log_url_for_graph_xform
1704        ):
1705            for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
1706                target = extract_target(node)
1707                if node.op == "call_module":
1708                    if (node.op, target) not in self.patterns:
1709                        continue
1710
1711                # Conservatively skip applying patterns to CPU inputs, since
1712                # some of the patterns induce codegen and split nodes.
1713                # Note: CPU compute is only skipped if disable_cpp_codegen=True
1714                if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
1715                    continue
1716
1717                for entry in self.patterns[(node.op, target)]:
1718                    if node._erased:
1719                        break
1720                    m = entry.pattern.match(node)
1721                    # pattern match crosses mutation barrier - discard
1722                    if (
1723                        is_match(m)
1724                        and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
1725                    ):
1726                        continue
1727                    if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
1728                        log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
1729                    if is_match(m) and entry.extra_check(m):
1730                        count += 1
1731                        entry.apply(m, graph, node)  # type: ignore[arg-type]
1732                        counters["inductor"]["pattern_matcher_count"] += 1
1733                        counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
1734        return count
1735
1736    def clear(self) -> None:
1737        self.patterns.clear()
1738
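
# A hedged end-to-end sketch of PatternMatcherPass: register a pattern entry
# directly (as joint_fwd_bwd() does below) and run the pass over a graph. The
# pattern and handler are illustrative.
def _example_pattern_matcher_pass(
    gm: torch.fx.GraphModule, handler: Callable[..., Any]
) -> int:
    example_pass = PatternMatcherPass(pass_name="example_pass")
    pattern = CallFunction(
        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
    )
    GraphPatternEntry(
        pattern=pattern, handler=handler, extra_check=_return_true
    ).register(example_pass.patterns)
    # apply() returns the number of replacements performed.
    return example_pass.apply(gm)
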
1739
1740def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
1741    raise NotImplementedError
1742
1743
1744def fx_to_pattern(
1745    gm: Union[torch.fx.GraphModule, torch.fx.Graph],
1746    ignore_types: Sequence[Type[Any]] = (),
1747    argnames: Sequence[str] = (),
1748    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
1749    exclusive_arg_names: Sequence[str] = (),
1750) -> PatternExpr:
1751    """
1752    Convert an FX graph into a PatternExpr.  This is useful for simple
1753    patterns that can only match single functions and fixed-length lists.
1754    """
1755    # scalar_workaround is a hack to capture dropout_p
1756    # see https://github.com/pytorch/pytorch/issues/97894
1757    scalar_workaround = scalar_workaround or {}
1758    inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
1759    assert len(inv_scalar_workaround) == len(scalar_workaround)
1760
1761    def process_arg(x: T) -> Union[T, KeywordArg, Ignored]:
1762        if isinstance(x, (float, int)) and x in inv_scalar_workaround:
1763            return KeywordArg(inv_scalar_workaround[x])
1764        if type(x) in ignore_types:
1765            return Ignored()
1766        if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
1767            return Ignored()
1768        return x
1769
1770    argnum = itertools.count()
1771
1772    class Converter(torch.fx.Interpreter):
1773        call_method = _not_implemented
1774        call_module = _not_implemented
1775        get_attr = _not_implemented
1776
1777        def placeholder(
1778            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
1779        ) -> Union[ExclusiveKeywordArg, KeywordArg]:
1780            n = next(argnum)
1781            if n < len(argnames):
1782                name = argnames[n]
1783            elif argnames:
1784                assert target.startswith("tangent")
1785                name = target
1786            else:
1787                target = re.sub(r"_\d+$", "", target)  # de-mangle arg name
1788                name = target
1789            if name in exclusive_arg_names:
1790                return ExclusiveKeywordArg(name)
1791            else:
1792                return KeywordArg(name)
1793
1794        def call_function(
1795            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
1796        ) -> PatternExpr:
1797            args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
1798            if list in ignore_types:
1799                # Handle burned-in tensor sizes, which are now [Ignored(), Ignored(), ...]
1800                args = [process_arg(a) for a in args]
1801                kwargs = {k: process_arg(a) for k, a in kwargs.items()}
1802            return CallFunction(target, *args, **kwargs)
1803
1804        def run_node(self, n: torch.fx.Node) -> Any:
1805            rv = super().run_node(n)
1806            if n.op == "output" and isinstance(rv, tuple):
1807                assert len(rv) == len(n.args[0])  # type: ignore[arg-type]
1808                for r, arg in zip(rv, n.args[0]):  # type: ignore[arg-type]
1809                    r.users = len(arg.users)
1810            else:
1811                rv.users = len(n.users)
1812            return rv
1813
1814    pattern = Converter(gm).run()  # type: ignore[arg-type]
1815    if not isinstance(pattern, PatternExpr):
1816        return MultiOutputPattern(pytree.tree_leaves(pattern))
1817    return pattern
1818
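
# A hedged sketch of fx_to_pattern() combined with fwd_only() below: trace a
# small search function into a normalized inference graph, then convert it
# into a PatternExpr (gen_pattern() above wraps these same steps). The search
# function and input shape are illustrative.
def _example_fx_to_pattern() -> PatternExpr:
    def example_search(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2

    gm = fwd_only(example_search, (torch.empty(8),))
    return fx_to_pattern(
        gm,
        ignore_types=(int, float, list, torch.device, torch.dtype),
        argnames=["x"],
    )
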
1819
1820@torch.no_grad()
1821def fwd_only(
1822    fn: Callable[..., Any],
1823    args: Sequence[Any],
1824    *,
1825    run_functional_passes: bool = True,
1826    get_decomp_fn: Optional[Callable[..., Any]] = None,
1827) -> torch.fx.GraphModule:
1828    """Build a normalized inference graph, for use with fx_to_pattern"""
1829    # TODO - look into using aot autograd, asserting no mutating ops here
1830    with enable_python_dispatcher():
1831        decompositions = (
1832            get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
1833        )
1834        gm = make_fx(fn, decompositions, tracing_mode="real")(*args)
1835
1836    from .fx_passes.post_grad import remove_noop_ops
1837
1838    if run_functional_passes:
1839        remove_noop_ops(gm.graph)
1840        gm.graph.eliminate_dead_code()
1841
1842    gm.recompile()
1843    return gm
1844
1845
1846@torch.enable_grad()
1847def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule:
1848    """Build a normalized training graph, for use with fx_to_pattern"""
1849    gm: Optional[torch.fx.GraphModule] = None
1850
1851    def record_joint_graph(
1852        joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any
1853    ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
1854        nonlocal gm
1855        assert not gm
1856        gm = clone_graph(joint_graph)
1857        return default_partition(joint_graph, inputs, **kwargs)
1858
1859    with torch._guards.tracing(None):
1860        aot_function(
1861            fn,
1862            lambda g, i: make_boxed_func(g),
1863            partition_fn=record_joint_graph,
1864            decompositions=select_decomp_table(),
1865            keep_inference_input_mutations=True,
1866            enable_log=False,
1867        )(*args)
1868    assert gm
1869
1870    from .fx_passes.post_grad import remove_noop_ops
1871
1872    remove_noop_ops(gm.graph)
1873
1874    from .fx_passes.joint_graph import pointless_view
1875
1876    matcher_pass = PatternMatcherPass()
1877
1878    pattern = CallFunction(
1879        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
1880    )
1881    GraphPatternEntry(
1882        pattern=pattern, handler=pointless_view, extra_check=_return_true
1883    ).register(matcher_pass.patterns)
1884    matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
1885
1886    # remove in/out specs
1887    gm.graph._codegen = torch.fx.graph.CodeGen()
1888    gm.graph.eliminate_dead_code()
1889    gm.recompile()
1890    return gm
1891
1892
1893def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
1894    args: List[torch.fx.node.Argument] = []
1895    torch.fx.map_arg((n.args, n.kwargs), args.append)
1896    return args
1897
1898
1899def stable_topological_sort(graph: torch.fx.Graph) -> None:
1900    # Nodes are in exactly one of these three collections:
1901
1902    # - Nodes in `pending` are waiting to be processed (in reverse order):
1903    pending = list(reversed(graph.nodes))
1904
1905    # - Nodes in `ready` have been processed and are already in the correct
1906    #   order.
1907    ready = set()
1908
1909    # - `waiting` is a mapping from a dependency to nodes which depend on that
1910    #   dependency.
1911    waiting = defaultdict(list)
1912
1913    # The cursor indicates the last processed node so we can add new nodes
1914    # after it.
1915    cursor = None
1916    while pending:
1917        node = pending.pop()
1918        waiting_for = [x for x in _args(node) if x not in ready]
1919        if waiting_for:
1920            # We have unprocessed input nodes. Might as well wait for the last
1921            # arg so an already sorted list will only recheck this node once.
1922            waiting[waiting_for[-1]].append(node)
1923        else:
1924            ready.add(node)
1925            if cursor and cursor.next is not node:
1926                cursor.append(node)
1927            cursor = node
1928            # Mark the nodes that have been waiting for this node to finish as
1929            # ready to check again.
1930            pending.extend(reversed(waiting.pop(node, ())))
1931
1932    assert not waiting and len(ready) == len(graph.nodes)
1933
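
# A hedged sketch checking the postcondition of stable_topological_sort():
# after the call every node appears after all of its inputs, while nodes that
# were already correctly ordered keep their relative order.
def _example_check_topological_order(graph: torch.fx.Graph) -> None:
    stable_topological_sort(graph)
    seen: Set[torch.fx.Node] = set()
    for node in graph.nodes:
        assert all(arg in seen for arg in _args(node)), node
        seen.add(node)
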
1934
1935def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
1936    """Wrapper around lazy init functions in fx_passes/"""
1937
1938    @functools.lru_cache(None)
1939    @functools.wraps(fn)
1940    def lazy_init() -> Any:
1941        counters_ref = counters["inductor"].copy()
1942
1943        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
1944            result = fn()
1945
1946        # clear view matches encountered during tracing
1947        counters["inductor"] = counters_ref
1948
1949        return result
1950
1951    return lazy_init
1952
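
# A hedged usage sketch of init_once_fakemode(): fx_passes modules decorate a
# lazy_init() entry point with it so pattern tracing runs only once, under a
# fresh FakeTensorMode, and any counters bumped during tracing are restored.
@init_once_fakemode
def _example_lazy_init() -> None:
    # A real lazy_init() would call gen_register_replacement() /
    # register_replacement() here to build its patterns on first use.
    ...
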
1953
1954def config_flag(name: str) -> Callable[[Match], Any]:
1955    """Function for extra_check to put pass behind a flag"""
1956    """Function for extra_check that puts a pass behind a config flag"""
1957    def flag_check(match: Match) -> Any:
1958        return getattr(config, name)
1959
1960    return flag_check
1961
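# A hedged usage example for config_flag() above (the flag name is
# illustrative): pass the returned callable as `extra_check=` when registering
# a pattern so it only fires while the named torch._inductor.config attribute
# is truthy, e.g.
#
#   register_lowering_pattern(pattern, extra_check=config_flag("pattern_matcher"), pass_dict=passes)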
1962
1963def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
1964    class CopyGraph(Transformer):
1965        def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node:
1966            new_node = super().run_node(old_node)
1967            if isinstance(new_node, torch.fx.Proxy):
1968                new_node.node.meta.update(old_node.meta)
1969                new_node.node.name = self.new_graph._graph_namespace.create_name(
1970                    old_node.name, None
1971                )
1972            return new_node
1973
1974    return CopyGraph(input_graph).transform()
1975
1976
1977_seen_patterns: Set[str] = set()
1978
1979
1980def get_arg_value(
1981    node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
1982) -> Any:
1983    return (
1984        node.args[arg_number]
1985        if len(node.args) > arg_number
1986        else node.kwargs.get(kwarg_name)  # type: ignore[arg-type]
1987    )
1988
1989
1990def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]:
1991    fns = [fn]
1992    if isinstance(fn, torch._ops.OpOverloadPacket):
1993        fns.extend([getattr(fn, overload) for overload in fn.overloads()])
1994
1995    return [node for node in nodes if node.target in fns]
1996
1997
1998def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
1999    """For call_function and call_method, we directly use the target function;
2000    for call_module, the target is a string, and we treat the module class
2001    as a function.
2002    """
2003    if node.op == "call_module":
2004        return getattr(node.graph.owning_module, node.target).__class__  # type: ignore[arg-type]
2005    return node.target
2006