xref: /aosp_15_r20/external/executorch/exir/pass_manager.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from typing import Callable, List, Optional, Union
10
11import torch
12import torch.fx.passes.infra.pass_manager as fx
13import torch.utils._pytree as pytree
14from executorch.exir.error import ExportError, ExportErrorType
15from torch.fx.passes.infra.pass_base import PassResult
16from typing_extensions import TypeAlias
17
18PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
19
20
class PassManager(fx.PassManager):
    """
    Class to run multiple passes on a given graph module. The PassManager is
    callable, so to run it we just call the PassManager instance.

    Passes may be given as a flat list or as a list of lists of passes. They
    are flattened in order, each one is wrapped with
    ``fx.pass_result_wrapper``, and the resulting list is handed to the base
    ``torch.fx`` PassManager together with the check-related flags.
    """

    def __init__(
        self,
        passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ) -> None:
        r"""
        Args:
            passes: A list of passes, or a list of lists of passes. Nested
                lists are flattened (preserving order) into a single sequence
                of passes. ``None`` or an empty list means no passes.
            run_checks_after_each_pass: whether to run checks and linting
                (see ``check``) after each pass; forwarded to the base
                PassManager.
            suppress_check_failures: forwarded unchanged to the base
                PassManager (presumably controls whether check failures are
                suppressed -- see torch.fx's PassManager).
        """
        # pytree.tree_flatten treats each pass callable as a leaf, so both a
        # flat list and a list of lists collapse into one ordered list. Each
        # pass is then wrapped with fx.pass_result_wrapper before being handed
        # to the base class.
        flattened_passes = [
            fx.pass_result_wrapper(fn)
            for fn in pytree.tree_flatten(passes or [])[0]
        ]

        super().__init__(
            flattened_passes,
            run_checks_after_each_pass=run_checks_after_each_pass,
            suppress_check_failures=suppress_check_failures,
        )

    def check(self, module: torch.nn.Module) -> None:
        """
        Run sanity checks on ``module`` to make sure it is fit for passes.

        Checks performed:
            - ``module`` must be a ``torch.fx.GraphModule``.
            - The module is recompiled and its graph is linted
              (``Graph.lint``).
            - The graph must not contain any ``call_method`` nodes.

        Not yet implemented (see the TODO in the body): verifying that the
        type of each operator node's output matches the node's spec field
        (e.g. if the op returns a tuple, the node's spec should be a tuple).

        Args:
            module: the module to check; must be a ``torch.fx.GraphModule``.

        Raises:
            AssertionError: if ``module`` is not a ``torch.fx.GraphModule``.
            ExportError: with ``ExportErrorType.NOT_SUPPORTED`` if the graph
                contains a ``call_method`` node.
        """
        # Explicit raise instead of `assert`, so the type check still runs
        # under `python -O`; AssertionError is kept so callers see the same
        # exception type as before.
        if not isinstance(module, torch.fx.GraphModule):
            raise AssertionError(
                f"Expected a torch.fx.GraphModule, got {type(module).__name__}"
            )
        module.recompile()
        module.graph.lint()
        # TODO(qihan): use verifier.check_is_exir

        for node in module.graph.nodes:
            if node.op == "call_method":
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    f"call_method `{node}` is not supported except for backend delegate.",
                )
79