# mypy: allow-untyped-defs
import abc
from collections import namedtuple
from typing import Optional

from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility


__all__ = ['PassResult', 'PassBase']


@compatibility(is_backward_compatible=False)
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
    """
    Result of a pass:
        graph_module: The modified graph module
        modified: A flag for if the pass has modified the graph module
    """

    def __new__(cls, graph_module, modified):
        # Forward both fields positionally to the namedtuple constructor.
        return super().__new__(cls, graph_module, modified)


@compatibility(is_backward_compatible=False)
class PassBase(abc.ABC):
    """
    Base interface for implementing passes.

    Subclasses are required to implement the `call` function so that
    instances of the pass can be handed directly to the PassManager and
    invoked as a function.

    An instance of a class implementing this interface can be placed
    directly into the PassManager's `passes` attribute.
    """

    def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        Runs the precondition check, the pass itself, and the postcondition
        check, in that order.

        The precondition check (`requires`) is run on the input module. The
        postcondition check (`ensures`) is run on the module produced by the
        pass: when `call` returns a `PassResult` carrying a graph module, that
        module is checked, because a pass may return a new module instead of
        mutating its input in place. Otherwise the input module is checked.

        Args:
            graph_module: The graph module to run the pass on

        Returns:
            Whatever `call` returned (a `PassResult` or `None`).
        """
        self.requires(graph_module)
        res = self.call(graph_module)
        # Validate postconditions on the pass's output rather than on the
        # (possibly replaced, hence stale) input module.
        if isinstance(res, PassResult) and res.graph_module is not None:
            checked_module = res.graph_module
        else:
            checked_module = graph_module
        self.ensures(checked_module)
        return res

    @abc.abstractmethod
    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        The pass that is run through the given graph module. To implement a
        pass, it is required to implement this function.

        Args:
            graph_module: The graph module we will run a pass on

        Returns:
            A `PassResult` describing the resulting module and whether it was
            modified, or `None`.
        """

    def requires(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        This function will be called before the pass is run and will check that
        the given graph module contains the preconditions needed to run the
        pass. It is not required to implement this function.

        Args:
            graph_module: The graph module we will run checks on
        """

    def ensures(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        This function will be called after the pass is run and will check that
        the graph module produced by the pass satisfies the pass's
        postconditions. It is not required to implement this function.

        Args:
            graph_module: The graph module we will run checks on (the module
                returned by the pass if it returned one, else the input)
        """