xref: /aosp_15_r20/external/pytorch/torch/fx/passes/infra/pass_base.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import abc
3from collections import namedtuple
4from typing import Optional
5
6from torch.fx.graph_module import GraphModule
7from torch.fx._compatibility import compatibility
8
9
10__all__ = ['PassResult', 'PassBase']
11
@compatibility(is_backward_compatible=False)
class PassResult(namedtuple("PassResult", "graph_module modified")):
    """
    Outcome of running a pass.

    Fields:
        graph_module: The graph module produced by the pass.
        modified: Flag indicating whether the pass modified the graph module.
    """

    def __new__(cls, graph_module, modified):
        # Both fields are required; hand them unchanged to the namedtuple
        # constructor so the tuple layout is (graph_module, modified).
        make = super().__new__
        return make(cls, graph_module, modified)
21
@compatibility(is_backward_compatible=False)
class PassBase(abc.ABC):
    """
    Base interface for implementing passes.

    Subclasses must implement the `call` method. Because instances are
    callable (see `__call__`), an instance of a class implementing this
    interface can be handed directly to the PassManager (for example via its
    `passes` attribute) and invoked like a function.
    """

    def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        Runs the precondition check, the pass itself, and the postcondition check.

        The postcondition check (`ensures`) is run on the graph module the
        pass produced: if `call` returns a `PassResult` carrying a graph
        module, that module is checked; otherwise (e.g. `call` returned None)
        the input module is checked.

        Args:
            graph_module: The graph module to run the pass on.

        Returns:
            Whatever `call` returned (a `PassResult` or None), unchanged.
        """
        self.requires(graph_module)
        res = self.call(graph_module)
        # A pass may return a different GraphModule than the one it was
        # given; postconditions must hold on the pass's output, not on the
        # (possibly stale) input.
        checked_module = graph_module
        if isinstance(res, PassResult) and res.graph_module is not None:
            checked_module = res.graph_module
        self.ensures(checked_module)
        return res

    @abc.abstractmethod
    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        The pass itself, run on the given graph module. Subclasses must
        implement this method.

        Args:
            graph_module: The graph module to run the pass on.

        Returns:
            A `PassResult` holding the resulting graph module and whether it
            was modified, or None.
        """

    def requires(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        Precondition check, called before the pass runs. Override it to check
        that the given graph module satisfies the preconditions needed to run
        the pass. Overriding it is optional; the default does nothing.

        Args:
            graph_module: The graph module to run checks on.
        """

    def ensures(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        Postcondition check, called after the pass runs. Override it to check
        that the given graph module satisfies the postconditions the pass is
        expected to establish. Overriding it is optional; the default does
        nothing.

        Args:
            graph_module: The graph module to run checks on. When the pass
                returned a `PassResult` with a graph module, this is that
                module; otherwise it is the module the pass was given.
        """
74