1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerfrom .graph_module import GraphModule 3*da0073e9SAndroid Build Coastguard Workerfrom ._lazy_graph_module import _make_graph_module 4*da0073e9SAndroid Build Coastguard Workerfrom .graph import Graph 5*da0073e9SAndroid Build Coastguard Workerfrom .node import Argument, Node, Target, map_arg, map_aggregate 6*da0073e9SAndroid Build Coastguard Workerfrom .proxy import Proxy 7*da0073e9SAndroid Build Coastguard Workerfrom ._symbolic_trace import Tracer 8*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility 9*da0073e9SAndroid Build Coastguard Workerfrom . import config 10*da0073e9SAndroid Build Coastguard Workerimport torch.fx.traceback as fx_traceback 11*da0073e9SAndroid Build Coastguard Workerimport torch 12*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, Iterator, List, Optional, Tuple, Union 13*da0073e9SAndroid Build Coastguard Workerimport inspect 14*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager 15*da0073e9SAndroid Build Coastguard Workerfrom torch.hub import tqdm 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker__all__ = ['Interpreter', 'Transformer'] 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 20*da0073e9SAndroid Build Coastguard Workerclass Interpreter: 21*da0073e9SAndroid Build Coastguard Worker """ 22*da0073e9SAndroid Build Coastguard Worker An Interpreter executes an FX graph Node-by-Node. This pattern 23*da0073e9SAndroid Build Coastguard Worker can be useful for many things, including writing code 24*da0073e9SAndroid Build Coastguard Worker transformations as well as analysis passes. 
25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker Methods in the Interpreter class can be overridden to customize 27*da0073e9SAndroid Build Coastguard Worker the behavior of execution. The map of overrideable methods 28*da0073e9SAndroid Build Coastguard Worker in terms of call hierarchy:: 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker run() 31*da0073e9SAndroid Build Coastguard Worker +-- run_node 32*da0073e9SAndroid Build Coastguard Worker +-- placeholder() 33*da0073e9SAndroid Build Coastguard Worker +-- get_attr() 34*da0073e9SAndroid Build Coastguard Worker +-- call_function() 35*da0073e9SAndroid Build Coastguard Worker +-- call_method() 36*da0073e9SAndroid Build Coastguard Worker +-- call_module() 37*da0073e9SAndroid Build Coastguard Worker +-- output() 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker Example: 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker Suppose we want to swap all instances of ``torch.neg`` with 42*da0073e9SAndroid Build Coastguard Worker ``torch.sigmoid`` and vice versa (including their ``Tensor`` 43*da0073e9SAndroid Build Coastguard Worker method equivalents). 
We could subclass Interpreter like so:: 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker class NegSigmSwapInterpreter(Interpreter): 46*da0073e9SAndroid Build Coastguard Worker def call_function(self, target : Target, 47*da0073e9SAndroid Build Coastguard Worker args : Tuple, kwargs : Dict) -> Any: 48*da0073e9SAndroid Build Coastguard Worker if target == torch.sigmoid: 49*da0073e9SAndroid Build Coastguard Worker return torch.neg(*args, **kwargs) 50*da0073e9SAndroid Build Coastguard Worker return super().call_function(n) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker def call_method(self, target : Target, 53*da0073e9SAndroid Build Coastguard Worker args : Tuple, kwargs : Dict) -> Any: 54*da0073e9SAndroid Build Coastguard Worker if target == 'neg': 55*da0073e9SAndroid Build Coastguard Worker call_self, *args_tail = args 56*da0073e9SAndroid Build Coastguard Worker return call_self.sigmoid(*args_tail, **kwargs) 57*da0073e9SAndroid Build Coastguard Worker return super().call_method(n) 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker def fn(x): 60*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(x).neg() 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(fn) 63*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 64*da0073e9SAndroid Build Coastguard Worker result = NegSigmSwapInterpreter(gm).run(input) 65*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, torch.neg(input).sigmoid()) 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker Args: 68*da0073e9SAndroid Build Coastguard Worker module (torch.nn.Module): The module to be executed 69*da0073e9SAndroid Build Coastguard Worker garbage_collect_values (bool): Whether to delete values after their last 70*da0073e9SAndroid Build Coastguard Worker use within 
the Module's execution. This ensures optimal memory usage during 71*da0073e9SAndroid Build Coastguard Worker execution. This can be disabled to, for example, examine all of the intermediate 72*da0073e9SAndroid Build Coastguard Worker values in the execution by looking at the ``Interpreter.env`` attribute. 73*da0073e9SAndroid Build Coastguard Worker graph (Optional[Graph]): If passed, the interpreter will execute this 74*da0073e9SAndroid Build Coastguard Worker graph instead of `module.graph`, using the provided `module` 75*da0073e9SAndroid Build Coastguard Worker argument to satisfy any requests for state. 76*da0073e9SAndroid Build Coastguard Worker """ 77*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 78*da0073e9SAndroid Build Coastguard Worker def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None): 79*da0073e9SAndroid Build Coastguard Worker self.module = module 80*da0073e9SAndroid Build Coastguard Worker self.submodules = dict(self.module.named_modules()) 81*da0073e9SAndroid Build Coastguard Worker if graph is not None: 82*da0073e9SAndroid Build Coastguard Worker self.graph = graph 83*da0073e9SAndroid Build Coastguard Worker else: 84*da0073e9SAndroid Build Coastguard Worker self.graph = self.module.graph 85*da0073e9SAndroid Build Coastguard Worker self.env : Dict[Node, Any] = {} 86*da0073e9SAndroid Build Coastguard Worker self.name = "Interpreter" 87*da0073e9SAndroid Build Coastguard Worker self.garbage_collect_values = garbage_collect_values 88*da0073e9SAndroid Build Coastguard Worker self.extra_traceback = True 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker if self.garbage_collect_values: 91*da0073e9SAndroid Build Coastguard Worker # Run through reverse nodes and record the first instance of a use 92*da0073e9SAndroid Build Coastguard Worker # of a given node. 
This represents the *last* use of the node in the 93*da0073e9SAndroid Build Coastguard Worker # execution order of the program, which we will use to free unused 94*da0073e9SAndroid Build Coastguard Worker # values 95*da0073e9SAndroid Build Coastguard Worker node_to_last_use : Dict[Node, Node] = {} 96*da0073e9SAndroid Build Coastguard Worker self.user_to_last_uses : Dict[Node, List[Node]] = {} 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker def register_last_uses(n : Node, user : Node): 99*da0073e9SAndroid Build Coastguard Worker if n not in node_to_last_use: 100*da0073e9SAndroid Build Coastguard Worker node_to_last_use[n] = user 101*da0073e9SAndroid Build Coastguard Worker self.user_to_last_uses.setdefault(user, []).append(n) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker for node in reversed(self.graph.nodes): 104*da0073e9SAndroid Build Coastguard Worker map_arg(node.args, lambda n: register_last_uses(n, node)) 105*da0073e9SAndroid Build Coastguard Worker map_arg(node.kwargs, lambda n: register_last_uses(n, node)) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 108*da0073e9SAndroid Build Coastguard Worker def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: 109*da0073e9SAndroid Build Coastguard Worker """ 110*da0073e9SAndroid Build Coastguard Worker Run `module` via interpretation and return the result. 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker Args: 113*da0073e9SAndroid Build Coastguard Worker *args: The arguments to the Module to run, in positional order 114*da0073e9SAndroid Build Coastguard Worker initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. 115*da0073e9SAndroid Build Coastguard Worker This is a dict mapping `Node` to any value. 
This can be used, for example, to 116*da0073e9SAndroid Build Coastguard Worker pre-populate results for certain `Nodes` so as to do only partial evaluation within 117*da0073e9SAndroid Build Coastguard Worker the interpreter. 118*da0073e9SAndroid Build Coastguard Worker enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and 119*da0073e9SAndroid Build Coastguard Worker process_outputs function first before using them. 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker Returns: 122*da0073e9SAndroid Build Coastguard Worker Any: The value returned from executing the Module 123*da0073e9SAndroid Build Coastguard Worker """ 124*da0073e9SAndroid Build Coastguard Worker self.env = initial_env if initial_env is not None else {} 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker # Positional function args are consumed left-to-right by 127*da0073e9SAndroid Build Coastguard Worker # `placeholder` nodes. Use an iterator to keep track of 128*da0073e9SAndroid Build Coastguard Worker # position and extract those values. 
129*da0073e9SAndroid Build Coastguard Worker if enable_io_processing: 130*da0073e9SAndroid Build Coastguard Worker args = self.graph.process_inputs(*args) 131*da0073e9SAndroid Build Coastguard Worker self.args_iter : Iterator[Any] = iter(args) 132*da0073e9SAndroid Build Coastguard Worker pbar = tqdm(total=len(self.graph.nodes), 133*da0073e9SAndroid Build Coastguard Worker desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", 134*da0073e9SAndroid Build Coastguard Worker initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker for node in self.graph.nodes: 137*da0073e9SAndroid Build Coastguard Worker pbar.update(1) 138*da0073e9SAndroid Build Coastguard Worker if node in self.env: 139*da0073e9SAndroid Build Coastguard Worker # Short circuit if we have this value. This could 140*da0073e9SAndroid Build Coastguard Worker # be used, for example, for partial evaluation 141*da0073e9SAndroid Build Coastguard Worker # where the caller has pre-populated `env` with 142*da0073e9SAndroid Build Coastguard Worker # values for a subset of the program. 
143*da0073e9SAndroid Build Coastguard Worker continue 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker try: 146*da0073e9SAndroid Build Coastguard Worker self.env[node] = self.run_node(node) 147*da0073e9SAndroid Build Coastguard Worker except Exception as e: 148*da0073e9SAndroid Build Coastguard Worker if self.extra_traceback: 149*da0073e9SAndroid Build Coastguard Worker msg = f"While executing {node.format_node()}" 150*da0073e9SAndroid Build Coastguard Worker msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg) 151*da0073e9SAndroid Build Coastguard Worker msg += f"\nOriginal traceback:\n{node.stack_trace}" 152*da0073e9SAndroid Build Coastguard Worker e.args = (msg,) + e.args[1:] 153*da0073e9SAndroid Build Coastguard Worker if isinstance(e, KeyError): 154*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(*e.args) from e 155*da0073e9SAndroid Build Coastguard Worker raise 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker if self.garbage_collect_values: 158*da0073e9SAndroid Build Coastguard Worker for to_delete in self.user_to_last_uses.get(node, []): 159*da0073e9SAndroid Build Coastguard Worker del self.env[to_delete] 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker if node.op == 'output': 162*da0073e9SAndroid Build Coastguard Worker output_val = self.env[node] 163*da0073e9SAndroid Build Coastguard Worker return self.graph.process_outputs(output_val) if enable_io_processing else output_val 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 166*da0073e9SAndroid Build Coastguard Worker def boxed_run(self, args_list): 167*da0073e9SAndroid Build Coastguard Worker """ 168*da0073e9SAndroid Build Coastguard Worker Run `module` via interpretation and return the result. 
This uses the "boxed" 169*da0073e9SAndroid Build Coastguard Worker calling convention, where you pass a list of arguments, which will be cleared 170*da0073e9SAndroid Build Coastguard Worker by the interpreter. This ensures that input tensors are promptly deallocated. 171*da0073e9SAndroid Build Coastguard Worker """ 172*da0073e9SAndroid Build Coastguard Worker args_iter = iter(args_list) 173*da0073e9SAndroid Build Coastguard Worker env = {} 174*da0073e9SAndroid Build Coastguard Worker for n in self.graph.nodes: 175*da0073e9SAndroid Build Coastguard Worker if n.op == "placeholder": 176*da0073e9SAndroid Build Coastguard Worker env[n] = next(args_iter) 177*da0073e9SAndroid Build Coastguard Worker args_list.clear() 178*da0073e9SAndroid Build Coastguard Worker return self.run(initial_env=env) 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker @contextmanager 181*da0073e9SAndroid Build Coastguard Worker def _set_current_node(self, node): 182*da0073e9SAndroid Build Coastguard Worker with fx_traceback.set_current_meta(node): 183*da0073e9SAndroid Build Coastguard Worker yield 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 186*da0073e9SAndroid Build Coastguard Worker def run_node(self, n : Node) -> Any: 187*da0073e9SAndroid Build Coastguard Worker """ 188*da0073e9SAndroid Build Coastguard Worker Run a specific node ``n`` and return the result. 
189*da0073e9SAndroid Build Coastguard Worker Calls into placeholder, get_attr, call_function, 190*da0073e9SAndroid Build Coastguard Worker call_method, call_module, or output depending 191*da0073e9SAndroid Build Coastguard Worker on ``node.op`` 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker Args: 194*da0073e9SAndroid Build Coastguard Worker n (Node): The Node to execute 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker Returns: 197*da0073e9SAndroid Build Coastguard Worker Any: The result of executing ``n`` 198*da0073e9SAndroid Build Coastguard Worker """ 199*da0073e9SAndroid Build Coastguard Worker with self._set_current_node(n): 200*da0073e9SAndroid Build Coastguard Worker args, kwargs = self.fetch_args_kwargs_from_env(n) 201*da0073e9SAndroid Build Coastguard Worker assert isinstance(args, tuple) 202*da0073e9SAndroid Build Coastguard Worker assert isinstance(kwargs, dict) 203*da0073e9SAndroid Build Coastguard Worker return getattr(self, n.op)(n.target, args, kwargs) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker # Main Node running APIs 206*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 207*da0073e9SAndroid Build Coastguard Worker def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 208*da0073e9SAndroid Build Coastguard Worker """ 209*da0073e9SAndroid Build Coastguard Worker Execute a ``placeholder`` node. Note that this is stateful: 210*da0073e9SAndroid Build Coastguard Worker ``Interpreter`` maintains an internal iterator over 211*da0073e9SAndroid Build Coastguard Worker arguments passed to ``run`` and this method returns 212*da0073e9SAndroid Build Coastguard Worker next() on that iterator. 
213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker Args: 215*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node. See 216*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 217*da0073e9SAndroid Build Coastguard Worker details on semantics 218*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 219*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker Returns: 222*da0073e9SAndroid Build Coastguard Worker Any: The argument value that was retrieved. 223*da0073e9SAndroid Build Coastguard Worker """ 224*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 225*da0073e9SAndroid Build Coastguard Worker if target.startswith('*'): 226*da0073e9SAndroid Build Coastguard Worker # For a starred parameter e.g. `*args`, retrieve all 227*da0073e9SAndroid Build Coastguard Worker # remaining values from the args list. 
228*da0073e9SAndroid Build Coastguard Worker return list(self.args_iter) 229*da0073e9SAndroid Build Coastguard Worker else: 230*da0073e9SAndroid Build Coastguard Worker try: 231*da0073e9SAndroid Build Coastguard Worker return next(self.args_iter) 232*da0073e9SAndroid Build Coastguard Worker except StopIteration as si: 233*da0073e9SAndroid Build Coastguard Worker if len(args) > 0: 234*da0073e9SAndroid Build Coastguard Worker return args[0] 235*da0073e9SAndroid Build Coastguard Worker else: 236*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 239*da0073e9SAndroid Build Coastguard Worker def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 240*da0073e9SAndroid Build Coastguard Worker """ 241*da0073e9SAndroid Build Coastguard Worker Execute a ``get_attr`` node. Will retrieve an attribute 242*da0073e9SAndroid Build Coastguard Worker value from the ``Module`` hierarchy of ``self.module``. 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker Args: 245*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node. 
See 246*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 247*da0073e9SAndroid Build Coastguard Worker details on semantics 248*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 249*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker Return: 252*da0073e9SAndroid Build Coastguard Worker Any: The value of the attribute that was retrieved 253*da0073e9SAndroid Build Coastguard Worker """ 254*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 255*da0073e9SAndroid Build Coastguard Worker return self.fetch_attr(target) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 258*da0073e9SAndroid Build Coastguard Worker def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 259*da0073e9SAndroid Build Coastguard Worker """ 260*da0073e9SAndroid Build Coastguard Worker Execute a ``call_function`` node and return the result. 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker Args: 263*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node. 
See 264*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 265*da0073e9SAndroid Build Coastguard Worker details on semantics 266*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 267*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker Return 270*da0073e9SAndroid Build Coastguard Worker Any: The value returned by the function invocation 271*da0073e9SAndroid Build Coastguard Worker """ 272*da0073e9SAndroid Build Coastguard Worker assert not isinstance(target, str) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker # Execute the function and return the result 275*da0073e9SAndroid Build Coastguard Worker return target(*args, **kwargs) 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 278*da0073e9SAndroid Build Coastguard Worker def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 279*da0073e9SAndroid Build Coastguard Worker """ 280*da0073e9SAndroid Build Coastguard Worker Execute a ``call_method`` node and return the result. 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker Args: 283*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node. 
See 284*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 285*da0073e9SAndroid Build Coastguard Worker details on semantics 286*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 287*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker Return 290*da0073e9SAndroid Build Coastguard Worker Any: The value returned by the method invocation 291*da0073e9SAndroid Build Coastguard Worker """ 292*da0073e9SAndroid Build Coastguard Worker # args[0] is the `self` object for this method call 293*da0073e9SAndroid Build Coastguard Worker self_obj, *args_tail = args 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker # Execute the method and return the result 296*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 297*da0073e9SAndroid Build Coastguard Worker return getattr(self_obj, target)(*args_tail, **kwargs) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 300*da0073e9SAndroid Build Coastguard Worker def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 301*da0073e9SAndroid Build Coastguard Worker """ 302*da0073e9SAndroid Build Coastguard Worker Execute a ``call_module`` node and return the result. 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker Args: 305*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node. 
See 306*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 307*da0073e9SAndroid Build Coastguard Worker details on semantics 308*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 309*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker Return 312*da0073e9SAndroid Build Coastguard Worker Any: The value returned by the module invocation 313*da0073e9SAndroid Build Coastguard Worker """ 314*da0073e9SAndroid Build Coastguard Worker # Retrieve executed args and kwargs values from the environment 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker # Execute the method and return the result 317*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 318*da0073e9SAndroid Build Coastguard Worker submod = self.fetch_attr(target) 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker return submod(*args, **kwargs) 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 323*da0073e9SAndroid Build Coastguard Worker def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 324*da0073e9SAndroid Build Coastguard Worker """ 325*da0073e9SAndroid Build Coastguard Worker Execute an ``output`` node. This really just retrieves 326*da0073e9SAndroid Build Coastguard Worker the value referenced by the ``output`` node and returns it. 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Worker Args: 329*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node. 
See 330*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 331*da0073e9SAndroid Build Coastguard Worker details on semantics 332*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 333*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker Return: 336*da0073e9SAndroid Build Coastguard Worker Any: The return value referenced by the output node 337*da0073e9SAndroid Build Coastguard Worker """ 338*da0073e9SAndroid Build Coastguard Worker return args[0] 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker # Helper methods 341*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 342*da0073e9SAndroid Build Coastguard Worker def fetch_attr(self, target : str): 343*da0073e9SAndroid Build Coastguard Worker """ 344*da0073e9SAndroid Build Coastguard Worker Fetch an attribute from the ``Module`` hierarchy of ``self.module``. 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker Args: 347*da0073e9SAndroid Build Coastguard Worker target (str): The fully-qualified name of the attribute to fetch 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker Return: 350*da0073e9SAndroid Build Coastguard Worker Any: The value of the attribute. 
351*da0073e9SAndroid Build Coastguard Worker """ 352*da0073e9SAndroid Build Coastguard Worker target_atoms = target.split('.') 353*da0073e9SAndroid Build Coastguard Worker attr_itr = self.module 354*da0073e9SAndroid Build Coastguard Worker for i, atom in enumerate(target_atoms): 355*da0073e9SAndroid Build Coastguard Worker if not hasattr(attr_itr, atom): 356*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}") 357*da0073e9SAndroid Build Coastguard Worker attr_itr = getattr(attr_itr, atom) 358*da0073e9SAndroid Build Coastguard Worker return attr_itr 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 361*da0073e9SAndroid Build Coastguard Worker def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: 362*da0073e9SAndroid Build Coastguard Worker """ 363*da0073e9SAndroid Build Coastguard Worker Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` 364*da0073e9SAndroid Build Coastguard Worker from the current execution environment. 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker Args: 367*da0073e9SAndroid Build Coastguard Worker n (Node): The node for which ``args`` and ``kwargs`` should be fetched. 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker Return: 370*da0073e9SAndroid Build Coastguard Worker Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. 
371*da0073e9SAndroid Build Coastguard Worker """ 372*da0073e9SAndroid Build Coastguard Worker args = self.map_nodes_to_values(n.args, n) 373*da0073e9SAndroid Build Coastguard Worker assert isinstance(args, tuple) 374*da0073e9SAndroid Build Coastguard Worker kwargs = self.map_nodes_to_values(n.kwargs, n) 375*da0073e9SAndroid Build Coastguard Worker assert isinstance(kwargs, dict) 376*da0073e9SAndroid Build Coastguard Worker return args, kwargs 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 379*da0073e9SAndroid Build Coastguard Worker def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: 380*da0073e9SAndroid Build Coastguard Worker """ 381*da0073e9SAndroid Build Coastguard Worker Recursively descend through ``args`` and look up the concrete value 382*da0073e9SAndroid Build Coastguard Worker for each ``Node`` in the current execution environment. 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker Args: 385*da0073e9SAndroid Build Coastguard Worker args (Argument): Data structure within which to look up concrete values 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker n (Node): Node to which ``args`` belongs. This is only used for error reporting. 388*da0073e9SAndroid Build Coastguard Worker """ 389*da0073e9SAndroid Build Coastguard Worker def load_arg(n_arg : Node) -> Any: 390*da0073e9SAndroid Build Coastguard Worker if n_arg not in self.env: 391*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! 
Run Graph.lint() ' 392*da0073e9SAndroid Build Coastguard Worker f'to diagnose such issues') 393*da0073e9SAndroid Build Coastguard Worker return self.env[n_arg] 394*da0073e9SAndroid Build Coastguard Worker return map_arg(args, load_arg) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 397*da0073e9SAndroid Build Coastguard Workerclass Transformer(Interpreter): 398*da0073e9SAndroid Build Coastguard Worker """ 399*da0073e9SAndroid Build Coastguard Worker ``Transformer`` is a special type of interpreter that produces a 400*da0073e9SAndroid Build Coastguard Worker new ``Module``. It exposes a ``transform()`` method that returns 401*da0073e9SAndroid Build Coastguard Worker the transformed ``Module``. Unlike ``Interpreter``, ``Transformer`` 402*da0073e9SAndroid Build Coastguard Worker does not require concrete arguments to run: it works 403*da0073e9SAndroid Build Coastguard Worker entirely symbolically, emitting new nodes into ``self.new_graph``. 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Worker Example: 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker Suppose we want to swap all instances of ``torch.neg`` with 408*da0073e9SAndroid Build Coastguard Worker ``torch.sigmoid`` and vice versa (including their ``Tensor`` 409*da0073e9SAndroid Build Coastguard Worker method equivalents).
We could subclass ``Transformer`` like so:: 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker class NegSigmSwapXformer(Transformer): 412*da0073e9SAndroid Build Coastguard Worker def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 413*da0073e9SAndroid Build Coastguard Worker if target == torch.sigmoid: 414*da0073e9SAndroid Build Coastguard Worker return torch.neg(*args, **kwargs) 415*da0073e9SAndroid Build Coastguard Worker return super().call_function(target, args, kwargs) 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 418*da0073e9SAndroid Build Coastguard Worker if target == 'neg': 419*da0073e9SAndroid Build Coastguard Worker call_self, *args_tail = args 420*da0073e9SAndroid Build Coastguard Worker return call_self.sigmoid(*args_tail, **kwargs) 421*da0073e9SAndroid Build Coastguard Worker return super().call_method(target, args, kwargs) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker def fn(x): 424*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(x).neg() 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(fn) 427*da0073e9SAndroid Build Coastguard Worker 428*da0073e9SAndroid Build Coastguard Worker transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() 429*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 4) 430*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker Args: 433*da0073e9SAndroid Build Coastguard Worker module (GraphModule): The ``Module`` to be transformed.
434*da0073e9SAndroid Build Coastguard Worker """ 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 437*da0073e9SAndroid Build Coastguard Worker def __init__(self, module): 438*da0073e9SAndroid Build Coastguard Worker super().__init__(module) 439*da0073e9SAndroid Build Coastguard Worker self.new_graph = Graph() 440*da0073e9SAndroid Build Coastguard Worker self.new_graph.set_codegen(module.graph._codegen) 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker class TransformerTracer(Tracer): 443*da0073e9SAndroid Build Coastguard Worker def __init__(self, graph: Graph): 444*da0073e9SAndroid Build Coastguard Worker super().__init__() 445*da0073e9SAndroid Build Coastguard Worker self.graph = graph 446*da0073e9SAndroid Build Coastguard Worker self.tensor_attrs: Dict[torch.Tensor, str] = {} # type: ignore[assignment] 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker def is_leaf_module(self, _, __) -> bool: 449*da0073e9SAndroid Build Coastguard Worker return True 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker self.tracer = TransformerTracer(self.new_graph) 452*da0073e9SAndroid Build Coastguard Worker self.tracer.root = module 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 455*da0073e9SAndroid Build Coastguard Worker def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: 456*da0073e9SAndroid Build Coastguard Worker """ 457*da0073e9SAndroid Build Coastguard Worker Execute a ``placeholder`` node. In ``Transformer``, this is 458*da0073e9SAndroid Build Coastguard Worker overridden to insert a new ``placeholder`` into the output 459*da0073e9SAndroid Build Coastguard Worker graph.
460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker Args: 462*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node. See 463*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 464*da0073e9SAndroid Build Coastguard Worker details on semantics 465*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 466*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation (not used by this override). If ``args`` is non-empty, its first element becomes the new placeholder's default value. Returns: Proxy: wraps the placeholder node created in ``self.new_graph``. 467*da0073e9SAndroid Build Coastguard Worker """ 468*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 469*da0073e9SAndroid Build Coastguard Worker default_value = next(iter(args)) if args else inspect.Signature.empty 470*da0073e9SAndroid Build Coastguard Worker return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 473*da0073e9SAndroid Build Coastguard Worker def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: 474*da0073e9SAndroid Build Coastguard Worker """ 475*da0073e9SAndroid Build Coastguard Worker Execute a ``get_attr`` node. In ``Transformer``, this is 476*da0073e9SAndroid Build Coastguard Worker overridden to insert a new ``get_attr`` node into the output 477*da0073e9SAndroid Build Coastguard Worker graph. 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker Args: 480*da0073e9SAndroid Build Coastguard Worker target (Target): The call target for this node.
See 481*da0073e9SAndroid Build Coastguard Worker `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for 482*da0073e9SAndroid Build Coastguard Worker details on semantics 483*da0073e9SAndroid Build Coastguard Worker args (Tuple): Tuple of positional args for this invocation 484*da0073e9SAndroid Build Coastguard Worker kwargs (Dict): Dict of keyword arguments for this invocation. Returns: Proxy: the ``get_attr`` proxy created via ``self.tracer.create_proxy``. 485*da0073e9SAndroid Build Coastguard Worker """ 486*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 487*da0073e9SAndroid Build Coastguard Worker return self.tracer.create_proxy("get_attr", target, args, kwargs) 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 490*da0073e9SAndroid Build Coastguard Worker def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 491*da0073e9SAndroid Build Coastguard Worker # Override so that the leaf module policy from `self.tracer` is respected (``TransformerTracer.is_leaf_module`` always returns True). 492*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 493*da0073e9SAndroid Build Coastguard Worker submod = self.fetch_attr(target) 494*da0073e9SAndroid Build Coastguard Worker return self.tracer.call_module(submod, submod.forward, args, kwargs) 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 497*da0073e9SAndroid Build Coastguard Worker def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: 498*da0073e9SAndroid Build Coastguard Worker # Override so that functions that were wrapped are still wrapped: the call is recorded via ``self.tracer.create_proxy``.
499*da0073e9SAndroid Build Coastguard Worker return self.tracer.create_proxy('call_function', target, args, kwargs) 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 502*da0073e9SAndroid Build Coastguard Worker def transform(self) -> GraphModule: 503*da0073e9SAndroid Build Coastguard Worker """ 504*da0073e9SAndroid Build Coastguard Worker Transform ``self.module`` and return the transformed 505*da0073e9SAndroid Build Coastguard Worker ``GraphModule``. The graph is run symbolically with ``enable_io_processing=False`` under ``fx_traceback.preserve_node_meta()``; if the run yields a non-``None`` result, an ``output`` node (with Proxies replaced by their underlying Nodes) is added to ``self.new_graph`` and the original output node's ``meta`` entries are copied onto it. Returns: GraphModule: built by ``_make_graph_module`` from ``self.module`` and ``self.new_graph``. 506*da0073e9SAndroid Build Coastguard Worker """ 507*da0073e9SAndroid Build Coastguard Worker with fx_traceback.preserve_node_meta(): 508*da0073e9SAndroid Build Coastguard Worker result = super().run(enable_io_processing=False) 509*da0073e9SAndroid Build Coastguard Worker if result is not None: 510*da0073e9SAndroid Build Coastguard Worker def strip_proxy(a : Union[Argument, Proxy]) -> Any: 511*da0073e9SAndroid Build Coastguard Worker return a.node if isinstance(a, Proxy) else a 512*da0073e9SAndroid Build Coastguard Worker new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy)) 513*da0073e9SAndroid Build Coastguard Worker # also copy the metadata from the old output node, which is expected to be the last node of ``self.graph`` 514*da0073e9SAndroid Build Coastguard Worker old_output_node = list(self.graph.nodes)[-1] 515*da0073e9SAndroid Build Coastguard Worker assert old_output_node.op == "output" 516*da0073e9SAndroid Build Coastguard Worker for k, v in old_output_node.meta.items(): 517*da0073e9SAndroid Build Coastguard Worker new_output_node.meta[k] = v 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker return _make_graph_module(self.module, self.new_graph) 521