xref: /aosp_15_r20/external/pytorch/torch/fx/interpreter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom .graph_module import GraphModule
3*da0073e9SAndroid Build Coastguard Workerfrom ._lazy_graph_module import _make_graph_module
4*da0073e9SAndroid Build Coastguard Workerfrom .graph import Graph
5*da0073e9SAndroid Build Coastguard Workerfrom .node import Argument, Node, Target, map_arg, map_aggregate
6*da0073e9SAndroid Build Coastguard Workerfrom .proxy import Proxy
7*da0073e9SAndroid Build Coastguard Workerfrom ._symbolic_trace import Tracer
8*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility
9*da0073e9SAndroid Build Coastguard Workerfrom . import config
10*da0073e9SAndroid Build Coastguard Workerimport torch.fx.traceback as fx_traceback
11*da0073e9SAndroid Build Coastguard Workerimport torch
12*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, Iterator, List, Optional, Tuple, Union
13*da0073e9SAndroid Build Coastguard Workerimport inspect
14*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager
15*da0073e9SAndroid Build Coastguard Workerfrom torch.hub import tqdm
16*da0073e9SAndroid Build Coastguard Worker
# Names exported by ``from torch.fx.interpreter import *``.
__all__ = [
    "Interpreter",
    "Transformer",
]
18*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overridable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). For brevity, the subclass below only handles
        the two forms that ``fn`` actually uses::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (torch.nn.Module): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
        graph (Optional[Graph]): If passed, the interpreter will execute this
            graph instead of `module.graph`, using the provided `module`
            argument to satisfy any requests for state.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
        self.module = module
        # Qualified-name -> submodule table. Not read by the base-class methods
        # below; kept as public state for subclasses.
        self.submodules = dict(self.module.named_modules())
        # Execute the explicitly supplied graph if given, else the module's own.
        self.graph = graph if graph is not None else self.module.graph
        # Maps each executed Node to its concrete runtime value.
        self.env : Dict[Node, Any] = {}
        # Label used in the progress-bar description in ``run``.
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values
        # When True, ``run`` annotates exceptions with the node being executed.
        self.extra_traceback = True
        # user node -> nodes whose last use is that user. Always defined so that
        # enabling ``garbage_collect_values`` after construction cannot make
        # ``run`` fail with AttributeError (the table is then simply empty and
        # nothing is freed).
        self.user_to_last_uses : Dict[Node, List[Node]] = {}

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use : Dict[Node, Node] = {}

            def register_last_uses(n : Node, user : Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.graph.nodes):
                # map_arg calls the callback immediately, so capturing the loop
                # variable ``node`` in these lambdas is safe.
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    @compatibility(is_backward_compatible=True)
    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter. The dict is used directly as ``self.env`` (it is not
                copied), so it is mutated during execution.
            enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
                process_outputs function first before using them.

        Returns:
            Any: The value returned from executing the Module, or ``None`` if the
            graph contains no ``output`` node.

        Raises:
            RuntimeError: When ``extra_traceback`` is set and executing a node
                raises ``KeyError``, it is re-raised as ``RuntimeError`` with the
                annotated message. Other exceptions propagate unchanged in type
                (with the node annotation prepended to their first argument when
                ``extra_traceback`` is set).
        """
        self.env = initial_env if initial_env is not None else {}

        # Positional function args are consumed left-to-right by
        # `placeholder` nodes. Use an iterator to keep track of
        # position and extract those values.
        if enable_io_processing:
            args = self.graph.process_inputs(*args)
        self.args_iter : Iterator[Any] = iter(args)
        pbar = tqdm(total=len(self.graph.nodes),
                    desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
                    initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)

        # Always release the progress bar, on the normal return path and when
        # a node raises.
        try:
            for node in self.graph.nodes:
                pbar.update(1)
                if node in self.env:
                    # Short circuit if we have this value. This could
                    # be used, for example, for partial evaluation
                    # where the caller has pre-populated `env` with
                    # values for a subset of the program.
                    continue

                try:
                    self.env[node] = self.run_node(node)
                except Exception as e:
                    if self.extra_traceback:
                        msg = f"While executing {node.format_node()}"
                        msg = f'{e.args[0]}\n\n{msg}' if e.args else msg
                        msg += f"\nOriginal traceback:\n{node.stack_trace}"
                        e.args = (msg,) + e.args[1:]
                        # KeyError's str() quotes its message, which would
                        # mangle the multi-line annotation; surface it as a
                        # RuntimeError instead.
                        if isinstance(e, KeyError):
                            raise RuntimeError(*e.args) from e
                    raise

                if self.garbage_collect_values:
                    # Free every value whose last consumer was this node.
                    for to_delete in self.user_to_last_uses.get(node, []):
                        del self.env[to_delete]

                if node.op == 'output':
                    output_val = self.env[node]
                    return self.graph.process_outputs(output_val) if enable_io_processing else output_val
        finally:
            pbar.close()

    @compatibility(is_backward_compatible=True)
    def boxed_run(self, args_list):
        """
        Run `module` via interpretation and return the result.  This uses the "boxed"
        calling convention, where you pass a list of arguments, which will be cleared
        by the interpreter.  This ensures that input tensors are promptly deallocated.

        Args:
            args_list (list): Values for the graph's ``placeholder`` nodes, in
                graph order. The list is cleared in place before execution.

        Returns:
            Any: The result of ``run`` on the bound placeholders.

        Raises:
            StopIteration: If ``args_list`` has fewer entries than the graph
                has placeholders.
        """
        args_iter = iter(args_list)
        env = {}
        # Bind every placeholder up front; ``run`` then short-circuits them
        # because they are already present in ``initial_env``.
        for n in self.graph.nodes:
            if n.op == "placeholder":
                env[n] = next(args_iter)
        # Drop the caller's references so only ``env`` keeps the inputs alive.
        args_list.clear()
        return self.run(initial_env=env)

    @contextmanager
    def _set_current_node(self, node):
        """Wrap the ``with`` body in ``fx_traceback.set_current_meta(node)``."""
        with fx_traceback.set_current_meta(node):
            yield

    @compatibility(is_backward_compatible=True)
    def run_node(self, n : Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``node.op``

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
        """
        with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            # Dispatch on the opcode name, e.g. 'call_function' -> self.call_function.
            return getattr(self, n.op)(n.target, args, kwargs)

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        next() on that iterator.

        For a starred parameter (``target`` starting with ``*``) all remaining
        arguments are returned as a list. If the iterator is exhausted and the
        placeholder carries a default value (``args[0]``), that default is
        returned.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The argument value that was retrieved.

        Raises:
            RuntimeError: If no argument remains and there is no default.
        """
        assert isinstance(target, str)
        if target.startswith('*'):
            # For a starred parameter e.g. `*args`, retrieve all
            # remaining values from the args list.
            return list(self.args_iter)
        else:
            try:
                return next(self.args_iter)
            except StopIteration as si:
                if len(args) > 0:
                    return args[0]
                else:
                    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value of the attribute that was retrieved
        """
        assert isinstance(target, str)
        return self.fetch_attr(target)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the function invocation
        """
        assert not isinstance(target, str)

        # Execute the function and return the result
        return target(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation;
                ``args[0]`` is the object the method is called on
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Execute the method and return the result
        assert isinstance(target, str)
        return getattr(self_obj, target)(*args_tail, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the module invocation
        """
        # Look up the submodule by its qualified name, then invoke it
        assert isinstance(target, str)
        submod = self.fetch_attr(target)

        return submod(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The return value referenced by the output node
        """
        return args[0]

    # Helper methods
    @compatibility(is_backward_compatible=True)
    def fetch_attr(self, target : str):
        """
        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (str): The fully-qualified name of the attribute to fetch

        Returns:
            Any: The value of the attribute.

        Raises:
            RuntimeError: If any dotted component of ``target`` does not exist.
        """
        target_atoms = target.split('.')
        attr_itr = self.module
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    @compatibility(is_backward_compatible=True)
    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
        """
        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
        from the current execution environment.

        Args:
            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

        Returns:
            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
        """
        args = self.map_nodes_to_values(n.args, n)
        assert isinstance(args, tuple)
        kwargs = self.map_nodes_to_values(n.kwargs, n)
        assert isinstance(kwargs, dict)
        return args, kwargs

    @compatibility(is_backward_compatible=True)
    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
        """
        Recursively descend through ``args`` and look up the concrete value
        for each ``Node`` in the current execution environment.

        Args:
            args (Argument): Data structure within which to look up concrete values

            n (Node): Node to which ``args`` belongs. This is only used for error reporting.

        Returns:
            Argument: ``args`` with every ``Node`` replaced by its value from ``self.env``.

        Raises:
            RuntimeError: If a referenced ``Node`` has no value in ``self.env``.
        """
        def load_arg(n_arg : Node) -> Any:
            if n_arg not in self.env:
                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
                                   f'to diagnose such issues')
            return self.env[n_arg]
        return map_arg(args, load_arg)
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True)
397*da0073e9SAndroid Build Coastguard Workerclass Transformer(Interpreter):
398*da0073e9SAndroid Build Coastguard Worker    """
399*da0073e9SAndroid Build Coastguard Worker    ``Transformer`` is a special type of interpreter that produces a
400*da0073e9SAndroid Build Coastguard Worker    new ``Module``. It exposes a ``transform()`` method that returns
401*da0073e9SAndroid Build Coastguard Worker    the transformed ``Module``. ``Transformer`` does not require
402*da0073e9SAndroid Build Coastguard Worker    arguments to run, as ``Interpreter`` does. ``Transformer`` works
403*da0073e9SAndroid Build Coastguard Worker    entirely symbolically.
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker    Example:
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker        Suppose we want to swap all instances of ``torch.neg`` with
408*da0073e9SAndroid Build Coastguard Worker        ``torch.sigmoid`` and vice versa (including their ``Tensor``
409*da0073e9SAndroid Build Coastguard Worker        method equivalents). We could subclass ``Transformer`` like so::
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker            class NegSigmSwapXformer(Transformer):
412*da0073e9SAndroid Build Coastguard Worker                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
413*da0073e9SAndroid Build Coastguard Worker                    if target == torch.sigmoid:
414*da0073e9SAndroid Build Coastguard Worker                        return torch.neg(*args, **kwargs)
415*da0073e9SAndroid Build Coastguard Worker                    return super().call_function(n)
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
418*da0073e9SAndroid Build Coastguard Worker                    if target == 'neg':
419*da0073e9SAndroid Build Coastguard Worker                        call_self, *args_tail = args
420*da0073e9SAndroid Build Coastguard Worker                        return call_self.sigmoid(*args_tail, **kwargs)
421*da0073e9SAndroid Build Coastguard Worker                    return super().call_method(n)
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Worker            def fn(x):
424*da0073e9SAndroid Build Coastguard Worker                return torch.sigmoid(x).neg()
425*da0073e9SAndroid Build Coastguard Worker
426*da0073e9SAndroid Build Coastguard Worker            gm = torch.fx.symbolic_trace(fn)
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
429*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(3, 4)
430*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker    Args:
433*da0073e9SAndroid Build Coastguard Worker        module (GraphModule): The ``Module`` to be transformed.
434*da0073e9SAndroid Build Coastguard Worker    """
435*da0073e9SAndroid Build Coastguard Worker
436*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
437*da0073e9SAndroid Build Coastguard Worker    def __init__(self, module):
438*da0073e9SAndroid Build Coastguard Worker        super().__init__(module)
439*da0073e9SAndroid Build Coastguard Worker        self.new_graph = Graph()
440*da0073e9SAndroid Build Coastguard Worker        self.new_graph.set_codegen(module.graph._codegen)
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker        class TransformerTracer(Tracer):
443*da0073e9SAndroid Build Coastguard Worker            def __init__(self, graph: Graph):
444*da0073e9SAndroid Build Coastguard Worker                super().__init__()
445*da0073e9SAndroid Build Coastguard Worker                self.graph = graph
446*da0073e9SAndroid Build Coastguard Worker                self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, _, __) -> bool:
449*da0073e9SAndroid Build Coastguard Worker                return True
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker        self.tracer = TransformerTracer(self.new_graph)
452*da0073e9SAndroid Build Coastguard Worker        self.tracer.root = module
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
455*da0073e9SAndroid Build Coastguard Worker    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
456*da0073e9SAndroid Build Coastguard Worker        """
457*da0073e9SAndroid Build Coastguard Worker        Execute a ``placeholder`` node. In ``Transformer``, this is
458*da0073e9SAndroid Build Coastguard Worker        overridden to insert a new ``placeholder`` into the output
459*da0073e9SAndroid Build Coastguard Worker        graph.
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker        Args:
462*da0073e9SAndroid Build Coastguard Worker            target (Target): The call target for this node. See
463*da0073e9SAndroid Build Coastguard Worker                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
464*da0073e9SAndroid Build Coastguard Worker                details on semantics
465*da0073e9SAndroid Build Coastguard Worker            args (Tuple): Tuple of positional args for this invocation
466*da0073e9SAndroid Build Coastguard Worker            kwargs (Dict): Dict of keyword arguments for this invocation
467*da0073e9SAndroid Build Coastguard Worker        """
468*da0073e9SAndroid Build Coastguard Worker        assert isinstance(target, str)
469*da0073e9SAndroid Build Coastguard Worker        default_value = next(iter(args)) if args else inspect.Signature.empty
470*da0073e9SAndroid Build Coastguard Worker        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
473*da0073e9SAndroid Build Coastguard Worker    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
474*da0073e9SAndroid Build Coastguard Worker        """
475*da0073e9SAndroid Build Coastguard Worker        Execute a ``get_attr`` node. In ``Transformer``, this is
476*da0073e9SAndroid Build Coastguard Worker        overridden to insert a new ``get_attr`` node into the output
477*da0073e9SAndroid Build Coastguard Worker        graph.
478*da0073e9SAndroid Build Coastguard Worker
479*da0073e9SAndroid Build Coastguard Worker        Args:
480*da0073e9SAndroid Build Coastguard Worker            target (Target): The call target for this node. See
481*da0073e9SAndroid Build Coastguard Worker                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
482*da0073e9SAndroid Build Coastguard Worker                details on semantics
483*da0073e9SAndroid Build Coastguard Worker            args (Tuple): Tuple of positional args for this invocation
484*da0073e9SAndroid Build Coastguard Worker            kwargs (Dict): Dict of keyword arguments for this invocation
485*da0073e9SAndroid Build Coastguard Worker        """
486*da0073e9SAndroid Build Coastguard Worker        assert isinstance(target, str)
487*da0073e9SAndroid Build Coastguard Worker        return self.tracer.create_proxy("get_attr", target, args, kwargs)
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
490*da0073e9SAndroid Build Coastguard Worker    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
491*da0073e9SAndroid Build Coastguard Worker        # Override so that the leaf module policy from `self.tracer` is respected.
492*da0073e9SAndroid Build Coastguard Worker        assert isinstance(target, str)
493*da0073e9SAndroid Build Coastguard Worker        submod = self.fetch_attr(target)
494*da0073e9SAndroid Build Coastguard Worker        return self.tracer.call_module(submod, submod.forward, args, kwargs)
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
497*da0073e9SAndroid Build Coastguard Worker    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
498*da0073e9SAndroid Build Coastguard Worker        # Override so that functions that were wrapped are still wrapped.
499*da0073e9SAndroid Build Coastguard Worker        return self.tracer.create_proxy('call_function', target, args, kwargs)
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
502*da0073e9SAndroid Build Coastguard Worker    def transform(self) -> GraphModule:
503*da0073e9SAndroid Build Coastguard Worker        """
504*da0073e9SAndroid Build Coastguard Worker        Transform ``self.module`` and return the transformed
505*da0073e9SAndroid Build Coastguard Worker        ``GraphModule``.
506*da0073e9SAndroid Build Coastguard Worker        """
507*da0073e9SAndroid Build Coastguard Worker        with fx_traceback.preserve_node_meta():
508*da0073e9SAndroid Build Coastguard Worker            result = super().run(enable_io_processing=False)
509*da0073e9SAndroid Build Coastguard Worker        if result is not None:
510*da0073e9SAndroid Build Coastguard Worker            def strip_proxy(a : Union[Argument, Proxy]) -> Any:
511*da0073e9SAndroid Build Coastguard Worker                return a.node if isinstance(a, Proxy) else a
512*da0073e9SAndroid Build Coastguard Worker            new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy))
513*da0073e9SAndroid Build Coastguard Worker            # also preserve the metadata from the old output node, if it exists
514*da0073e9SAndroid Build Coastguard Worker            old_output_node = list(self.graph.nodes)[-1]
515*da0073e9SAndroid Build Coastguard Worker            assert old_output_node.op == "output"
516*da0073e9SAndroid Build Coastguard Worker            for k, v in old_output_node.meta.items():
517*da0073e9SAndroid Build Coastguard Worker                new_output_node.meta[k] = v
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker        return _make_graph_module(self.module, self.new_graph)
521