# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Callable, List, Optional, Union

import torch
import torch.fx.passes.infra.pass_manager as fx
import torch.utils._pytree as pytree
from executorch.exir.error import ExportError, ExportErrorType
from torch.fx.passes.infra.pass_base import PassResult
from typing_extensions import TypeAlias

# A pass is any callable taking a GraphModule and optionally returning a
# PassResult.
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]


class PassManager(fx.PassManager):
    """
    Class to run multiple passes on a given graph module. The PassManager is
    callable, so to run it we can just call the PassManager instance.

    Passes may be given as a flat list or as nested lists of passes. They are
    flattened (preserving order), each wrapped with ``fx.pass_result_wrapper``,
    and handed to the torch.fx ``PassManager`` base class together with the
    check-related flags; the base class owns how they are stored and run.
    """

    def __init__(
        self,
        passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ) -> None:
        r"""
        Args:
            passes: A list of passes, or a list of lists of passes. Nested
                lists are flattened in order into a single list. ``None`` or
                an empty list means no passes.
            run_checks_after_each_pass: whether to run checks and linting after
                each pass
            suppress_check_failures: forwarded unchanged to the torch.fx
                ``PassManager`` base class
        """
        # Flatten (possibly nested) pass lists into one ordered list of
        # callables; each is wrapped with fx.pass_result_wrapper before being
        # passed to the base class.
        flattened_passes = [
            fx.pass_result_wrapper(fn)
            for fn in pytree.tree_flatten(passes or [])[0]
        ]

        super().__init__(
            flattened_passes,
            run_checks_after_each_pass=run_checks_after_each_pass,
            suppress_check_failures=suppress_check_failures,
        )

    def check(self, module: torch.nn.Module) -> None:
        """
        Runs checks on the given graph module to make sure it contains the
        needed data for passes.

        Currently enforced:
            - The module has type ``torch.fx.GraphModule``.
            - The module is recompiled and its graph passes ``Graph.lint()``.
            - The graph contains no ``call_method`` nodes.

        Not yet enforced (see the TODO in the body):
            - Types of operator nodes match the types specified in the
              node's spec field (ex. if the op returns a tuple then the
              node's spec field is a tuple).

        Raises:
            AssertionError: if ``module`` is not a ``torch.fx.GraphModule``.
            ExportError: with ``ExportErrorType.NOT_SUPPORTED`` if the graph
                contains a ``call_method`` node.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for backward
        # compatibility with callers that may catch it.
        if not isinstance(module, torch.fx.GraphModule):
            raise AssertionError(
                f"Expected a torch.fx.GraphModule, got {type(module).__name__}"
            )
        # Regenerate the module's code from its (possibly edited) graph, then
        # validate the graph's structural invariants.
        module.recompile()
        module.graph.lint()
        # TODO(qihan): use verifier.check_is_exir

        for node in module.graph.nodes:
            if node.op == "call_method":
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    f"call_method `{node}` is not supported except for backend delegate.",
                )