xref: /aosp_15_r20/external/pytorch/torch/fx/graph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
3*da0073e9SAndroid Build Coastguard Workerfrom .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
4*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
5*da0073e9SAndroid Build Coastguard Workerfrom . import _pytree as fx_pytree
6*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility
7*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _NodeIter
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport os
10*da0073e9SAndroid Build Coastguard Workerimport contextlib
11*da0073e9SAndroid Build Coastguard Workerfrom typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable
12*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
13*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager
14*da0073e9SAndroid Build Coastguard Workerimport copy
15*da0073e9SAndroid Build Coastguard Workerimport enum
16*da0073e9SAndroid Build Coastguard Workerimport torch
17*da0073e9SAndroid Build Coastguard Workerimport keyword
18*da0073e9SAndroid Build Coastguard Workerimport re
19*da0073e9SAndroid Build Coastguard Workerimport builtins
20*da0073e9SAndroid Build Coastguard Workerimport math
21*da0073e9SAndroid Build Coastguard Workerimport warnings
22*da0073e9SAndroid Build Coastguard Workerimport inspect
23*da0073e9SAndroid Build Coastguard Worker
# Names exported by a star-import of this module.
__all__ = ["PythonCode", "CodeGen", "Graph"]
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
27*da0073e9SAndroid Build Coastguard Worker    from .graph_module import GraphModule  # noqa: F401
28*da0073e9SAndroid Build Coastguard Worker    from ._symbolic_trace import Tracer   # noqa: F401
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
# Maps each builtin container type to the `typing` generic that names it,
# e.g. `list` -> `typing.List`.
_origin_type_map = dict(
    [
        (list, List),
        (dict, Dict),
        (set, Set),
        (frozenset, FrozenSet),
        (tuple, Tuple),
    ]
)


# Signature of a function that rewrites the body of the generated code: it
# receives the body as a `list[str]` and returns the replacement `list[str]`.
TransformCodeFunc = Callable[[List[str]], List[str]]
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
class _CustomBuiltin(NamedTuple):
    """Extra objects placed in the globals of every generated FX graph.

    Some objects have a ``repr()`` that is only valid Python once a name has
    been bound (e.g. ``math.inf``, whose repr is ``inf``). The most common of
    these are bundled into every FX graph's globals so generated source can
    refer to them directly.
    """
    # Python statement that binds the name, e.g. 'from math import inf'.
    import_str: str
    # The object that statement binds.
    obj: Any


# Name -> custom builtin; filled by `_register_custom_builtin` below.
_custom_builtins: Dict[str, _CustomBuiltin] = {}


def _register_custom_builtin(name: str, import_str: str, obj: Any):
    """Register ``obj`` as the custom builtin called ``name``.

    Registering the same ``name`` again replaces the earlier entry.
    """
    entry = _CustomBuiltin(import_str=import_str, obj=obj)
    _custom_builtins[name] = entry


# Registration order is preserved in `_custom_builtins` (dict insertion order).
_register_custom_builtin(name='inf', import_str='from math import inf', obj=math.inf)
_register_custom_builtin(name='nan', import_str='from math import nan', obj=math.nan)
_register_custom_builtin(name='NoneType', import_str='NoneType = type(None)', obj=type(None))
_register_custom_builtin(name='torch', import_str='import torch', obj=torch)
_register_custom_builtin(name='device', import_str='from torch import device', obj=torch.device)
_register_custom_builtin(
    name='fx_pytree', import_str='import torch.fx._pytree as fx_pytree', obj=fx_pytree
)
_register_custom_builtin(
    name='pytree', import_str='import torch.utils._pytree as pytree', obj=pytree
)
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker
def _is_magic(x: str) -> bool:
    """Return True if ``x`` is dunder-style, i.e. begins and ends with ``__``."""
    head, tail = x[:2], x[-2:]
    return head == '__' and tail == '__'
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker
def _snake_case(s: str) -> str:
    """
    Convert ``s`` into a Python-style (snake_case) variable name.

    Every character is lower-cased, and an underscore is inserted before an
    upper-case character that directly follows a lower-case one.

    Examples:
        ``mod.snake_case`` -> ``mod.snake_case``
        ``mod.pascalCase`` -> ``mod.pascal_case``
        ``mod.ALL_CAPS`` -> ``mod.all_caps``
    """
    out = []
    # Pair every character with its predecessor; the leading ' ' stands in for
    # "no predecessor" and is never lower-case.
    for before, ch in zip(' ' + s, s):
        if ch.isupper() and before.islower():
            out.append('_')
        out.append(ch.lower())
    return ''.join(out)
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker
def _is_from_torch(obj: Any) -> bool:
    """Return True if ``obj`` should be treated as coming from ``torch``.

    When ``obj`` has a non-None ``__module__``, that alone decides: the module
    must be in the ``torch`` package but not under ``torch._dynamo.`` or
    ``torch._inductor.``. Otherwise ``obj`` counts as a torch object when it is
    the very attribute of ``torch`` or ``torch.nn.functional`` named by its
    ``__name__``.
    """
    module_name = getattr(obj, '__module__', None)
    if module_name is not None:
        if module_name.partition('.')[0] != 'torch':
            return False
        return not module_name.startswith(("torch._dynamo.", "torch._inductor."))

    name = getattr(obj, '__name__', None)
    # Objects named 'torch' are deliberately excluded: attribute chains such as
    # `torch.torch` reportedly resolve, so a name lookup can't be trusted there.
    if name is None or name == 'torch':
        return False
    return any(
        getattr(namespace, name, None) is obj
        for namespace in (torch, torch.nn.functional)
    )
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker
class _Namespace:
    """A context for associating names uniquely with objects.

    The following invariants are enforced:
    - Each object gets a single name.
    - Each name is unique within a given namespace.
    - Names generated do not shadow builtins, unless the object is indeed that builtin.
    - Names generated are never Python keywords, and do not shadow the
      registered custom builtins (``_custom_builtins``) unless the object is
      that custom builtin.
    """
    def __init__(self):
        # Object -> the name it was given.
        self._obj_to_name: Dict[Any, str] = {}
        # Names produced by `create_name(..., obj=None)` that have not yet been
        # bound to an object via `associate_name_with_obj`.
        self._unassociated_names: Set[str] = set()
        # Every name handed out so far. Nothing in this class removes entries,
        # so a name stays reserved even after `_rename_object`.
        self._used_names: Set[str] = set()
        # Base name -> last numeric suffix recorded for it. This is only a
        # starting point for the suffix search in `create_name`; uniqueness is
        # still checked against `_used_names`.
        self._base_count: Dict[str, int] = defaultdict(int)

        # Runs of characters that cannot appear in a Python identifier.
        self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
        # Splits off a trailing `_<digits>` suffix: "add_3" -> ("add", "3").
        self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")

    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
        """Create a unique name.

        Arguments:
            candidate: used as the basis for the unique name, relevant to the user.
            obj: If not None, an object that will be associated with the unique name.

        Returns:
            The name now reserved in this namespace. If ``obj`` already has a
            name, that existing name is returned and ``candidate`` is ignored.
        """
        if obj is not None and obj in self._obj_to_name:
            return self._obj_to_name[obj]

        # replace each run of characters that is illegal in a Python
        # identifier with a single underscore
        candidate = self._illegal_char_regex.sub('_', candidate)

        if not candidate:
            candidate = '_unnamed'

        # identifiers cannot start with a digit
        if candidate[0].isdigit():
            candidate = f'_{candidate}'

        # split off a trailing `_<digits>` suffix, if present
        match = self._name_suffix_regex.match(candidate)
        if match is None:
            base = candidate
            num = None
        else:
            base, num_str = match.group(1, 2)
            num = int(num_str)

        # re-render the suffix, which normalizes e.g. "x_007" to "x_7"
        candidate = base if num is None else f'{base}_{num}'
        # no suffix (or a suffix of 0): continue counting from the last suffix
        # recorded for this base
        if not num:
            num = self._base_count[base]

        # bump the suffix until the name is unused and legal for `obj`
        while candidate in self._used_names or self._is_illegal_name(candidate, obj):
            num += 1
            candidate = f'{base}_{num}'

        self._used_names.add(candidate)
        self._base_count[base] = num
        if obj is None:
            self._unassociated_names.add(candidate)
        else:
            self._obj_to_name[obj] = candidate
        return candidate

    def associate_name_with_obj(self, name: str, obj: Any):
        """Associate a unique name with an object.

        Neither `name` nor `obj` should be associated already; `name` must have
        been produced by `create_name(..., obj=None)`. These preconditions are
        checked with `assert`, so they are not enforced under `python -O`.
        """
        assert obj not in self._obj_to_name
        assert name in self._unassociated_names
        self._obj_to_name[obj] = name
        self._unassociated_names.remove(name)

    def _is_illegal_name(self, name: str, obj: Any) -> bool:
        """Return True if ``name`` may not be used as the name of ``obj``."""
        # 1. keywords are never allowed as names.
        if name in keyword.kwlist:
            return True

        # 2. Can't shadow a builtin name, unless you *are* that builtin.
        if name in builtins.__dict__:
            return obj is not builtins.__dict__[name]

        # 3. Can't shadow our custom builtins either
        if name in _custom_builtins:
            return obj is not _custom_builtins[name].obj

        return False

    def _rename_object(self, obj: Any, name: str):
        """Point the already-named ``obj`` at ``name`` and mark ``name`` as used.

        NOTE(review): ``name`` is not checked for uniqueness or legality here,
        and the object's previous name is left in ``_used_names``. Callers are
        presumably expected to pass a name obtained from ``create_name`` --
        confirm at call sites.
        """
        assert obj in self._obj_to_name
        self._obj_to_name[obj] = name
        self._used_names.add(name)
206*da0073e9SAndroid Build Coastguard Worker
# Short abbreviations for torch dtypes, e.g. torch.float32 -> 'f32'.
# Each is a kind prefix ('f' float, 'bf' bfloat, 'c' complex, 'i' signed int,
# 'u' unsigned int, 'b' bool/bits) followed by the bit width; the float8
# variants additionally carry their format suffix (e.g. 'e4m3fn').
dtype_abbrs = {
    torch.bfloat16: 'bf16',
    torch.float64: 'f64',
    torch.float32: 'f32',
    torch.float16: 'f16',
    torch.float8_e4m3fn: 'f8e4m3fn',
    torch.float8_e5m2: 'f8e5m2',
    torch.float8_e4m3fnuz: 'f8e4m3fnuz',
    torch.float8_e5m2fnuz: 'f8e5m2fnuz',
    torch.complex32: 'c32',
    torch.complex64: 'c64',
    torch.complex128: 'c128',
    torch.int8: 'i8',
    torch.int16: 'i16',
    torch.int32: 'i32',
    torch.int64: 'i64',
    torch.bool: 'b8',
    torch.uint8: 'u8',
    torch.uint16: 'u16',
    torch.uint32: 'u32',
    torch.uint64: 'u64',
    torch.bits16: 'b16',
}
230*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=True)
@dataclass
class PythonCode:
    """
    Represents all the information necessary to exec or save a graph as Python code.
    """
    # Python source code for the forward function definition.
    src: str
    # Values in global scope during execution of `src`.
    globals: Dict[str, Any]
    # Optional mapping from the forward function's line number to
    # node index. The field has no default, so it must always be passed
    # (possibly as None) when constructing a PythonCode.
    _lineno_map: Optional[Dict[int, Optional[int]]]
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker
def _format_target(base: str, target: str) -> str:
    """Build a Python expression that looks up the dotted ``target`` on ``base``.

    Each ``.``-separated component becomes plain attribute access when it is
    usable as one, and a ``getattr`` call otherwise, e.g.
    ``_format_target('self', 'seq.0') == 'getattr(self.seq, "0")'``.

    Components that are valid identifiers but also Python keywords (``if``,
    ``class``, ...) go through ``getattr`` too, because ``self.if`` is a
    SyntaxError. Components containing a double quote, a backslash or a
    non-printable character are emitted with ``repr`` so the generated string
    literal is always well formed.
    """
    expr = base
    for elem in target.split('.'):
        if elem.isidentifier() and not keyword.iskeyword(elem):
            expr = f'{expr}.{elem}'
        elif '"' in elem or '\\' in elem or not elem.isprintable():
            # A bare "..." literal would be malformed; let repr() escape it.
            expr = f'getattr({expr}, {elem!r})'
        else:
            # Keep the historical double-quoted form for ordinary names.
            expr = f'getattr({expr}, "{elem}")'
    return expr
255*da0073e9SAndroid Build Coastguard Worker
class _InsertPoint:
    """Temporarily replace ``graph._insert``; restore it when the ``with`` exits.

    The replacement happens as soon as the object is constructed, not in
    ``__enter__``. ``__exit__`` restores the saved value whether or not the
    ``with`` body raised, and never suppresses the exception.
    """
    def __init__(self, graph, new_insert):
        self.graph = graph
        saved_insert = graph._insert
        self.orig_insert = saved_insert
        graph._insert = new_insert

    def __enter__(self):
        # Nothing useful to bind for `with ... as x`; x is None.
        return None

    def __exit__(self, type, value, tb):
        # Falls through to an implicit None (falsy) return, so any exception
        # raised in the `with` body propagates.
        self.graph._insert = self.orig_insert
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Workerclass _node_list:
268*da0073e9SAndroid Build Coastguard Worker    def __init__(self, graph: 'Graph', direction: str = '_next'):
269*da0073e9SAndroid Build Coastguard Worker        assert direction in ['_next', '_prev']
270*da0073e9SAndroid Build Coastguard Worker        self.graph = graph
271*da0073e9SAndroid Build Coastguard Worker        self.direction = direction
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker    def __len__(self):
274*da0073e9SAndroid Build Coastguard Worker        return self.graph._len
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
277*da0073e9SAndroid Build Coastguard Worker        assert self.direction == "_prev" or self.direction == "_next"
278*da0073e9SAndroid Build Coastguard Worker        yield from _NodeIter(self.graph._root, self.direction == "_prev")
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker    def __reversed__(self):
281*da0073e9SAndroid Build Coastguard Worker        return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
282*da0073e9SAndroid Build Coastguard Worker
class _PyTreeInfo(NamedTuple):
    """
    Contains extra info stored when we're using Pytrees
    """
    # Argument names of the original function (presumably the names before the
    # inputs were flattened -- confirm against the pytree codegen using this).
    orig_args: List[str]
    # Pytree spec describing the structure of the inputs.
    in_spec: pytree.TreeSpec
    # Pytree spec for the outputs; typed Optional, so it may be None.
    out_spec: Optional[pytree.TreeSpec]
290*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class _ParsedStackTrace:
    """
    One frame extracted from a stack trace by ``_parse_stack_trace``.

    ``lineno`` is kept as the string matched in the trace, and ``code`` is the
    stripped source line of the frame ('' when the trace had none).
    """
    file: str
    lineno: str
    name: str
    code: str

    def get_summary_str(self):
        return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}'


# Matches a frame header line such as: File "foo.py", line 12, in bar
_STACK_TRACE_FRAME_RE = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")


def _parse_stack_trace(stack_trace: Optional[str]) -> Optional[_ParsedStackTrace]:
    """Extract the innermost frame, with its source line, from ``stack_trace``.

    ``stack_trace`` is presumably formatted like Python's ``traceback`` output:
    innermost frame last, each frame a ``File "...", line N, in name`` header
    normally followed by the frame's source line.

    Returns None when ``stack_trace`` is None or no frame header is found
    (the last line is never treated as a header, since it has no following
    source line).
    """
    if stack_trace is None:
        return None
    lines = stack_trace.strip().split('\n')
    # Walk backwards from the second-to-last line so every candidate header
    # has a following line to read the code from.
    for idx in range(len(lines) - 2, -1, -1):
        matches = _STACK_TRACE_FRAME_RE.match(lines[idx].strip())
        if matches is None:
            continue
        file, lineno, name = matches.group(1, 2, 3)
        code = lines[idx + 1].strip()
        # A frame whose source is unavailable is followed directly by the next
        # frame's header; don't report that header as this frame's code.
        if _STACK_TRACE_FRAME_RE.match(code):
            code = ''
        return _ParsedStackTrace(file, lineno, name, code)
    return None
325*da0073e9SAndroid Build Coastguard Worker
326*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=False)
327*da0073e9SAndroid Build Coastguard Workerclass CodeGen:
328*da0073e9SAndroid Build Coastguard Worker    def __init__(self):
329*da0073e9SAndroid Build Coastguard Worker        self._body_transformer: Optional[TransformCodeFunc] = None
330*da0073e9SAndroid Build Coastguard Worker        self._func_name: str = "forward"
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker    def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
333*da0073e9SAndroid Build Coastguard Worker        """
334*da0073e9SAndroid Build Coastguard Worker        Given the free variables and a return annotation, generates the beginning of the FX function.
335*da0073e9SAndroid Build Coastguard Worker        By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`
336*da0073e9SAndroid Build Coastguard Worker        """
337*da0073e9SAndroid Build Coastguard Worker        # If the original function didn't have self as its first argument, we
338*da0073e9SAndroid Build Coastguard Worker        # would have added it.
339*da0073e9SAndroid Build Coastguard Worker        if len(free_vars) == 0 or free_vars[0] != 'self':
340*da0073e9SAndroid Build Coastguard Worker            free_vars.insert(0, 'self')
341*da0073e9SAndroid Build Coastguard Worker        return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker    def generate_output(self, output_args: Argument) -> str:
344*da0073e9SAndroid Build Coastguard Worker        """
345*da0073e9SAndroid Build Coastguard Worker        Given the output arguments, generates the return statement of the FX function.
346*da0073e9SAndroid Build Coastguard Worker        Note: The returned statement should not be indented.
347*da0073e9SAndroid Build Coastguard Worker        """
348*da0073e9SAndroid Build Coastguard Worker        return f'return {repr(output_args)}'
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker    def process_inputs(self, *args: Any) -> Any:
351*da0073e9SAndroid Build Coastguard Worker        """
352*da0073e9SAndroid Build Coastguard Worker        Transforms the inputs so that the graph can take them as arguments, as
353*da0073e9SAndroid Build Coastguard Worker        non-default codegen may result in the inputs to the function being
354*da0073e9SAndroid Build Coastguard Worker        different from the inputs to the graph.
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker        If the graph was directly runnable, this invariant should hold true
357*da0073e9SAndroid Build Coastguard Worker        `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`
358*da0073e9SAndroid Build Coastguard Worker        """
359*da0073e9SAndroid Build Coastguard Worker        return args
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker    def process_outputs(self, outputs: Any) -> Any:
362*da0073e9SAndroid Build Coastguard Worker        """
363*da0073e9SAndroid Build Coastguard Worker        Transforms the outputs of the graph to be identical to the codegen.
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker        See ``process_inputs`` for more details.
366*da0073e9SAndroid Build Coastguard Worker        """
367*da0073e9SAndroid Build Coastguard Worker        return outputs
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker    def additional_globals(self) -> List[Tuple[str, Any]]:
370*da0073e9SAndroid Build Coastguard Worker        """
371*da0073e9SAndroid Build Coastguard Worker        If your codegen uses extra global values, add tuples of (identifier,reference to the value) here.
372*da0073e9SAndroid Build Coastguard Worker        For example, return ['List', typing.List] if you need ``List`` in the global context.
373*da0073e9SAndroid Build Coastguard Worker        """
374*da0073e9SAndroid Build Coastguard Worker        return []
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker    def _gen_python_code(
377*da0073e9SAndroid Build Coastguard Worker        self, nodes, root_module: str, namespace: _Namespace, *,
378*da0073e9SAndroid Build Coastguard Worker        verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False
379*da0073e9SAndroid Build Coastguard Worker    ) -> PythonCode:
380*da0073e9SAndroid Build Coastguard Worker        free_vars: List[str] = []
381*da0073e9SAndroid Build Coastguard Worker        body: List[str] = []
382*da0073e9SAndroid Build Coastguard Worker        globals_: Dict[str, Any] = {}
383*da0073e9SAndroid Build Coastguard Worker        wrapped_fns: Dict[str, None] = {}
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker        # Wrap string in list to pass by reference
386*da0073e9SAndroid Build Coastguard Worker        maybe_return_annotation : List[str] = ['']
387*da0073e9SAndroid Build Coastguard Worker        include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1")
388*da0073e9SAndroid Build Coastguard Worker        include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1")
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker        def add_global(name_hint: str, obj: Any):
391*da0073e9SAndroid Build Coastguard Worker            """Add an obj to be tracked as a global.
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker            We call this for names that reference objects external to the
394*da0073e9SAndroid Build Coastguard Worker            Graph, like functions or types.
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker            Returns: the global name that should be used to reference 'obj' in generated source.
397*da0073e9SAndroid Build Coastguard Worker            """
398*da0073e9SAndroid Build Coastguard Worker            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
399*da0073e9SAndroid Build Coastguard Worker                # HACK: workaround for how torch custom ops are registered. We
400*da0073e9SAndroid Build Coastguard Worker                # can't import them like normal modules so they must retain their
401*da0073e9SAndroid Build Coastguard Worker                # fully qualified name.
402*da0073e9SAndroid Build Coastguard Worker                return _get_qualified_name(obj)
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker            # normalize the name hint to get a proper identifier
405*da0073e9SAndroid Build Coastguard Worker            global_name = namespace.create_name(name_hint, obj)
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker            if global_name in globals_:
408*da0073e9SAndroid Build Coastguard Worker                assert globals_[global_name] is obj
409*da0073e9SAndroid Build Coastguard Worker                return global_name
410*da0073e9SAndroid Build Coastguard Worker            globals_[global_name] = obj
411*da0073e9SAndroid Build Coastguard Worker            return global_name
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker        # Pre-fill the globals table with registered builtins.
414*da0073e9SAndroid Build Coastguard Worker        for name, (_, obj) in _custom_builtins.items():
415*da0073e9SAndroid Build Coastguard Worker            add_global(name, obj)
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker        def type_repr(o : Any):
418*da0073e9SAndroid Build Coastguard Worker            if o == ():
419*da0073e9SAndroid Build Coastguard Worker                # Empty tuple is used for empty tuple type annotation Tuple[()]
420*da0073e9SAndroid Build Coastguard Worker                return '()'
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker            typename = _type_repr(o)
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker            if hasattr(o, '__origin__'):
425*da0073e9SAndroid Build Coastguard Worker                # This is a generic type, e.g. typing.List[torch.Tensor]
426*da0073e9SAndroid Build Coastguard Worker                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
427*da0073e9SAndroid Build Coastguard Worker                origin_typename = add_global(_type_repr(origin_type), origin_type)
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker                if hasattr(o, '__args__'):
430*da0073e9SAndroid Build Coastguard Worker                    # Assign global names for each of the inner type variables.
431*da0073e9SAndroid Build Coastguard Worker                    args = [type_repr(arg) for arg in o.__args__]
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker                    if len(args) == 0:
434*da0073e9SAndroid Build Coastguard Worker                        # Bare type, such as `typing.Tuple` with no subscript
435*da0073e9SAndroid Build Coastguard Worker                        # This code-path used in Python < 3.9
436*da0073e9SAndroid Build Coastguard Worker                        return origin_typename
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker                    return f'{origin_typename}[{",".join(args)}]'
439*da0073e9SAndroid Build Coastguard Worker                else:
440*da0073e9SAndroid Build Coastguard Worker                    # Bare type, such as `typing.Tuple` with no subscript
441*da0073e9SAndroid Build Coastguard Worker                    # This code-path used in Python 3.9+
442*da0073e9SAndroid Build Coastguard Worker                    return origin_typename
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker            # Common case: this is a regular module name like 'foo.bar.baz'
445*da0073e9SAndroid Build Coastguard Worker            return add_global(typename, o)
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker        codes = {
448*da0073e9SAndroid Build Coastguard Worker            "yellow": "\033[33m",
449*da0073e9SAndroid Build Coastguard Worker            "cyan": "\033[36m",
450*da0073e9SAndroid Build Coastguard Worker            "green": "\033[32m",
451*da0073e9SAndroid Build Coastguard Worker            "blue": "\033[34m",
452*da0073e9SAndroid Build Coastguard Worker            "red": "\033[31m",
453*da0073e9SAndroid Build Coastguard Worker            "dim": "\033[2m",
454*da0073e9SAndroid Build Coastguard Worker            "dim_blue": "\033[2m\033[34m",
455*da0073e9SAndroid Build Coastguard Worker            "dim_green": "\033[2m\033[32m",
456*da0073e9SAndroid Build Coastguard Worker            "reset": "\033[0m",
457*da0073e9SAndroid Build Coastguard Worker        }
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker        def make_wrapper_func(name):
460*da0073e9SAndroid Build Coastguard Worker            def f(s):
461*da0073e9SAndroid Build Coastguard Worker                if colored:
462*da0073e9SAndroid Build Coastguard Worker                    return f"{codes[name]}{s}{codes['reset']}"
463*da0073e9SAndroid Build Coastguard Worker                return s
464*da0073e9SAndroid Build Coastguard Worker            return f
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker        yellow = make_wrapper_func("yellow")
467*da0073e9SAndroid Build Coastguard Worker        cyan = make_wrapper_func("cyan")
468*da0073e9SAndroid Build Coastguard Worker        red = make_wrapper_func("red")
469*da0073e9SAndroid Build Coastguard Worker        green = make_wrapper_func("green")
470*da0073e9SAndroid Build Coastguard Worker        dim_green = make_wrapper_func("dim_green")
471*da0073e9SAndroid Build Coastguard Worker        dim = make_wrapper_func("dim")
472*da0073e9SAndroid Build Coastguard Worker        dim_blue = make_wrapper_func("dim_blue")
473*da0073e9SAndroid Build Coastguard Worker        blue = make_wrapper_func("blue")
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker        def _get_repr(arg: Any) -> str:
476*da0073e9SAndroid Build Coastguard Worker            # Handle NamedTuples (if it has `_fields`) via add_global.
477*da0073e9SAndroid Build Coastguard Worker            if isinstance(arg, tuple) and hasattr(arg, '_fields'):
478*da0073e9SAndroid Build Coastguard Worker                qualified_name = _get_qualified_name(type(arg))
479*da0073e9SAndroid Build Coastguard Worker                global_name = add_global(qualified_name, type(arg))
480*da0073e9SAndroid Build Coastguard Worker                return f"{global_name}{repr(tuple(arg))}"
481*da0073e9SAndroid Build Coastguard Worker            elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
482*da0073e9SAndroid Build Coastguard Worker                qualified_name = _get_qualified_name(arg)
483*da0073e9SAndroid Build Coastguard Worker                global_name = add_global(qualified_name, arg)
484*da0073e9SAndroid Build Coastguard Worker                return f"{global_name}"
485*da0073e9SAndroid Build Coastguard Worker            elif isinstance(arg, enum.Enum):
486*da0073e9SAndroid Build Coastguard Worker                cls = arg.__class__
487*da0073e9SAndroid Build Coastguard Worker                clsname = add_global(cls.__name__, cls)
488*da0073e9SAndroid Build Coastguard Worker                return f"{clsname}.{arg.name}"
489*da0073e9SAndroid Build Coastguard Worker            elif isinstance(arg, Node):
490*da0073e9SAndroid Build Coastguard Worker                return repr(arg)
491*da0073e9SAndroid Build Coastguard Worker            elif isinstance(arg, torch.Tensor):
492*da0073e9SAndroid Build Coastguard Worker                size = list(arg.size())
493*da0073e9SAndroid Build Coastguard Worker                dtype = str(arg.dtype).split(".")[-1]
494*da0073e9SAndroid Build Coastguard Worker                return f"torch.Tensor(size={size}, dtype={dtype})"
495*da0073e9SAndroid Build Coastguard Worker            else:
496*da0073e9SAndroid Build Coastguard Worker                return blue(repr(arg))
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
500*da0073e9SAndroid Build Coastguard Worker            args_s = ', '.join(_get_repr(a) for a in args)
501*da0073e9SAndroid Build Coastguard Worker            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
502*da0073e9SAndroid Build Coastguard Worker            if args_s and kwargs_s:
503*da0073e9SAndroid Build Coastguard Worker                return f'{args_s}, {kwargs_s}'
504*da0073e9SAndroid Build Coastguard Worker            return args_s or kwargs_s
505*da0073e9SAndroid Build Coastguard Worker
506*da0073e9SAndroid Build Coastguard Worker        # Run through reverse nodes and record the first instance of a use
507*da0073e9SAndroid Build Coastguard Worker        # of a given node. This represents the *last* use of the node in the
508*da0073e9SAndroid Build Coastguard Worker        # execution order of the program, which we will use to free unused
509*da0073e9SAndroid Build Coastguard Worker        # values
510*da0073e9SAndroid Build Coastguard Worker        node_to_last_use : Dict[Node, Node] = {}
511*da0073e9SAndroid Build Coastguard Worker        user_to_last_uses : Dict[Node, List[Node]] = {}
512*da0073e9SAndroid Build Coastguard Worker
513*da0073e9SAndroid Build Coastguard Worker        def register_last_uses(n : Node, user : Node):
514*da0073e9SAndroid Build Coastguard Worker            if n not in node_to_last_use:
515*da0073e9SAndroid Build Coastguard Worker                node_to_last_use[n] = user
516*da0073e9SAndroid Build Coastguard Worker                user_to_last_uses.setdefault(user, []).append(n)
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker        for node in reversed(nodes):
519*da0073e9SAndroid Build Coastguard Worker            map_arg(node.args, lambda n: register_last_uses(n, node))
520*da0073e9SAndroid Build Coastguard Worker            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker        def delete_unused_values(user : Node):
523*da0073e9SAndroid Build Coastguard Worker            """
524*da0073e9SAndroid Build Coastguard Worker            Delete values after their last use. This ensures that values that are
525*da0073e9SAndroid Build Coastguard Worker            not used in the remainder of the code are freed and the memory usage
526*da0073e9SAndroid Build Coastguard Worker            of the code is optimal.
527*da0073e9SAndroid Build Coastguard Worker            """
528*da0073e9SAndroid Build Coastguard Worker            if user.op == 'placeholder':
529*da0073e9SAndroid Build Coastguard Worker                return
530*da0073e9SAndroid Build Coastguard Worker            if user.op == 'output':
531*da0073e9SAndroid Build Coastguard Worker                body.append('\n')
532*da0073e9SAndroid Build Coastguard Worker                return
533*da0073e9SAndroid Build Coastguard Worker            nodes_to_delete = user_to_last_uses.get(user, [])
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker            if len(user.users.keys()) == 0:
536*da0073e9SAndroid Build Coastguard Worker                # This node is not used by any others. however it's also not
537*da0073e9SAndroid Build Coastguard Worker                # removed by DCE since side-effect. We want to free it's outputs
538*da0073e9SAndroid Build Coastguard Worker                # right after its execution done to save memory.
539*da0073e9SAndroid Build Coastguard Worker                nodes_to_delete.append(user)
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker            if len(nodes_to_delete):
542*da0073e9SAndroid Build Coastguard Worker                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
543*da0073e9SAndroid Build Coastguard Worker                body.append(f';  {dim(to_delete_str)}\n')
544*da0073e9SAndroid Build Coastguard Worker            else:
545*da0073e9SAndroid Build Coastguard Worker                body.append('\n')
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        prev_stacktrace = None
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker        def append_stacktrace_summary(node : Node):
550*da0073e9SAndroid Build Coastguard Worker            """
551*da0073e9SAndroid Build Coastguard Worker            Append a summary of the stacktrace to the generated code. This is
552*da0073e9SAndroid Build Coastguard Worker            useful for debugging.
553*da0073e9SAndroid Build Coastguard Worker            """
554*da0073e9SAndroid Build Coastguard Worker            nonlocal prev_stacktrace
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker            if node.op not in {'placeholder', 'output'}:
557*da0073e9SAndroid Build Coastguard Worker                if node.stack_trace:
558*da0073e9SAndroid Build Coastguard Worker                    if node.stack_trace != prev_stacktrace:
559*da0073e9SAndroid Build Coastguard Worker                        prev_stacktrace = node.stack_trace
560*da0073e9SAndroid Build Coastguard Worker                        summary_str = ""
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker                        if parsed_stack_trace := _parse_stack_trace(node.stack_trace):
563*da0073e9SAndroid Build Coastguard Worker                            summary_str = parsed_stack_trace.get_summary_str()
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker                        body.append(f'\n {dim("# " + summary_str)}\n')
566*da0073e9SAndroid Build Coastguard Worker                elif prev_stacktrace != "":
567*da0073e9SAndroid Build Coastguard Worker                    prev_stacktrace = ""
568*da0073e9SAndroid Build Coastguard Worker                    no_stacktrace_msg = "# No stacktrace found for following nodes"
569*da0073e9SAndroid Build Coastguard Worker                    body.append(f'\n{dim(no_stacktrace_msg)}\n')
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker        def stringify_shape(shape : Iterable) -> str:
572*da0073e9SAndroid Build Coastguard Worker            return f"[{', '.join(str(x) for x in shape)}]"
573*da0073e9SAndroid Build Coastguard Worker
574*da0073e9SAndroid Build Coastguard Worker        def emit_node(node : Node):
575*da0073e9SAndroid Build Coastguard Worker            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker            if verbose:
578*da0073e9SAndroid Build Coastguard Worker                # override annotation with more detailed information
579*da0073e9SAndroid Build Coastguard Worker                from torch.fx.experimental.proxy_tensor import py_sym_types
580*da0073e9SAndroid Build Coastguard Worker                from torch.fx.passes.shape_prop import TensorMetadata
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker                meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None)))
583*da0073e9SAndroid Build Coastguard Worker                # use string as annotation, to make it valid python code
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker                if isinstance(meta_val, torch.Tensor):
586*da0073e9SAndroid Build Coastguard Worker                    stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else ""
587*da0073e9SAndroid Build Coastguard Worker                    device_annotation = f"{meta_val.device}" if include_device else ""
588*da0073e9SAndroid Build Coastguard Worker                    maybe_type_annotation = \
589*da0073e9SAndroid Build Coastguard Worker                        f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \
590*da0073e9SAndroid Build Coastguard Worker                        f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"'
591*da0073e9SAndroid Build Coastguard Worker                elif isinstance(meta_val, py_sym_types):
592*da0073e9SAndroid Build Coastguard Worker                    maybe_type_annotation = f': "Sym({meta_val})"'
593*da0073e9SAndroid Build Coastguard Worker                elif isinstance(meta_val, TensorMetadata):
594*da0073e9SAndroid Build Coastguard Worker                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker            if node.op == 'placeholder':
597*da0073e9SAndroid Build Coastguard Worker                assert isinstance(node.target, str)
598*da0073e9SAndroid Build Coastguard Worker                maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}'
599*da0073e9SAndroid Build Coastguard Worker                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
600*da0073e9SAndroid Build Coastguard Worker                raw_name = node.target.replace('*', '')
601*da0073e9SAndroid Build Coastguard Worker                if raw_name != repr(node):
602*da0073e9SAndroid Build Coastguard Worker                    body.append(f'{repr(node)} = {raw_name}\n')
603*da0073e9SAndroid Build Coastguard Worker                return
604*da0073e9SAndroid Build Coastguard Worker            elif node.op == 'call_method':
605*da0073e9SAndroid Build Coastguard Worker                assert isinstance(node.target, str)
606*da0073e9SAndroid Build Coastguard Worker                body.append(
607*da0073e9SAndroid Build Coastguard Worker                    f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}'
608*da0073e9SAndroid Build Coastguard Worker                    f'({_format_args(node.args[1:], node.kwargs)})')
609*da0073e9SAndroid Build Coastguard Worker                return
610*da0073e9SAndroid Build Coastguard Worker            elif node.op == 'call_function':
611*da0073e9SAndroid Build Coastguard Worker                assert callable(node.target)
612*da0073e9SAndroid Build Coastguard Worker                # pretty print operators
613*da0073e9SAndroid Build Coastguard Worker                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods:
614*da0073e9SAndroid Build Coastguard Worker                    assert isinstance(node.args, tuple)
615*da0073e9SAndroid Build Coastguard Worker                    body.append(f'{repr(node)}{maybe_type_annotation} = '
616*da0073e9SAndroid Build Coastguard Worker                                f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}')
617*da0073e9SAndroid Build Coastguard Worker                    return
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker                # pretty print inplace operators; required for jit.script to work properly
620*da0073e9SAndroid Build Coastguard Worker                # not currently supported in normal FX graphs, but generated by torchdynamo
621*da0073e9SAndroid Build Coastguard Worker                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods:
622*da0073e9SAndroid Build Coastguard Worker                    body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))};  '
623*da0073e9SAndroid Build Coastguard Worker                                f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}')
624*da0073e9SAndroid Build Coastguard Worker                    return
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker                qualified_name = _get_qualified_name(node.target)
627*da0073e9SAndroid Build Coastguard Worker                global_name = add_global(qualified_name, node.target)
628*da0073e9SAndroid Build Coastguard Worker                # special case for getattr: node.args could be 2-argument or 3-argument
629*da0073e9SAndroid Build Coastguard Worker                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
630*da0073e9SAndroid Build Coastguard Worker                if global_name == 'getattr' and \
631*da0073e9SAndroid Build Coastguard Worker                   isinstance(node.args, tuple) and \
632*da0073e9SAndroid Build Coastguard Worker                   isinstance(node.args[1], str) and \
633*da0073e9SAndroid Build Coastguard Worker                   node.args[1].isidentifier() and \
634*da0073e9SAndroid Build Coastguard Worker                   len(node.args) == 2:
635*da0073e9SAndroid Build Coastguard Worker                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}')
636*da0073e9SAndroid Build Coastguard Worker                    return
637*da0073e9SAndroid Build Coastguard Worker                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
638*da0073e9SAndroid Build Coastguard Worker                if node.meta.get('is_wrapped', False):
639*da0073e9SAndroid Build Coastguard Worker                    wrapped_fns.setdefault(global_name)
640*da0073e9SAndroid Build Coastguard Worker                return
641*da0073e9SAndroid Build Coastguard Worker            elif node.op == 'call_module':
642*da0073e9SAndroid Build Coastguard Worker                assert isinstance(node.target, str)
643*da0073e9SAndroid Build Coastguard Worker                body.append(f'{repr(node)}{maybe_type_annotation} = '
644*da0073e9SAndroid Build Coastguard Worker                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
645*da0073e9SAndroid Build Coastguard Worker                return
646*da0073e9SAndroid Build Coastguard Worker            elif node.op == 'get_attr':
647*da0073e9SAndroid Build Coastguard Worker                assert isinstance(node.target, str)
648*da0073e9SAndroid Build Coastguard Worker                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
649*da0073e9SAndroid Build Coastguard Worker                return
650*da0073e9SAndroid Build Coastguard Worker            elif node.op == 'output':
651*da0073e9SAndroid Build Coastguard Worker                if node.type is not None:
652*da0073e9SAndroid Build Coastguard Worker                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
653*da0073e9SAndroid Build Coastguard Worker                body.append(self.generate_output(node.args[0]))
654*da0073e9SAndroid Build Coastguard Worker                return
655*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError(f'node: {node.op} {node.target}')
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker        for i, node in enumerate(nodes):
658*da0073e9SAndroid Build Coastguard Worker            # NOTE: emit_node does not emit a string with newline. It depends
659*da0073e9SAndroid Build Coastguard Worker            # on delete_unused_values to append one
660*da0073e9SAndroid Build Coastguard Worker            if verbose:
661*da0073e9SAndroid Build Coastguard Worker                append_stacktrace_summary(node)
662*da0073e9SAndroid Build Coastguard Worker            # emit a counter comment to keep track of
663*da0073e9SAndroid Build Coastguard Worker            # node index, which will be deleted later
664*da0073e9SAndroid Build Coastguard Worker            # after going through _body_transformer
665*da0073e9SAndroid Build Coastguard Worker            body.append(f"# COUNTER: {i}\n")
666*da0073e9SAndroid Build Coastguard Worker            emit_node(node)
667*da0073e9SAndroid Build Coastguard Worker            delete_unused_values(node)
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker        if len(body) == 0:
670*da0073e9SAndroid Build Coastguard Worker            # If the Graph has no non-placeholder nodes, no lines for the body
671*da0073e9SAndroid Build Coastguard Worker            # have been emitted. To continue to have valid Python code, emit a
672*da0073e9SAndroid Build Coastguard Worker            # single pass statement
673*da0073e9SAndroid Build Coastguard Worker            body.append('pass\n')
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker        if len(wrapped_fns) > 0:
678*da0073e9SAndroid Build Coastguard Worker            wrap_name = add_global('wrap', torch.fx.wrap)
679*da0073e9SAndroid Build Coastguard Worker            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
680*da0073e9SAndroid Build Coastguard Worker        else:
681*da0073e9SAndroid Build Coastguard Worker            wrap_stmts = ''
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker        if self._body_transformer:
684*da0073e9SAndroid Build Coastguard Worker            body = self._body_transformer(body)
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker        for name, value in self.additional_globals():
687*da0073e9SAndroid Build Coastguard Worker            add_global(name, value)
688*da0073e9SAndroid Build Coastguard Worker
689*da0073e9SAndroid Build Coastguard Worker        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker        # remove counter and generate lineno to node index mapping
692*da0073e9SAndroid Build Coastguard Worker        lineno_map: Dict[int, Optional[int]] = {}
693*da0073e9SAndroid Build Coastguard Worker        prologue_len = prologue.count('\n') + 1
694*da0073e9SAndroid Build Coastguard Worker        new_lines: List[str] = []
695*da0073e9SAndroid Build Coastguard Worker        cur_idx = None
696*da0073e9SAndroid Build Coastguard Worker        for line in ''.join(body).split('\n'):
697*da0073e9SAndroid Build Coastguard Worker            counter = re.search(r"# COUNTER: (\d+)", line)
698*da0073e9SAndroid Build Coastguard Worker            if counter and counter.group(1) is not None:
699*da0073e9SAndroid Build Coastguard Worker                cur_idx = int(counter.group(1))
700*da0073e9SAndroid Build Coastguard Worker            else:
701*da0073e9SAndroid Build Coastguard Worker                lineno_map[len(new_lines) + prologue_len] = cur_idx
702*da0073e9SAndroid Build Coastguard Worker                new_lines.append(line)
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker        code = "\n".join(new_lines).lstrip('\n')
705*da0073e9SAndroid Build Coastguard Worker        code = '\n'.join('    ' + line for line in code.split('\n'))
706*da0073e9SAndroid Build Coastguard Worker
707*da0073e9SAndroid Build Coastguard Worker        fn_code = f"""
708*da0073e9SAndroid Build Coastguard Worker{wrap_stmts}
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker{prologue}
711*da0073e9SAndroid Build Coastguard Worker{code}"""
712*da0073e9SAndroid Build Coastguard Worker        return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
713*da0073e9SAndroid Build Coastguard Worker
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker# Ideally, we'd like to refactor all of the pytree logic into this codegen
716*da0073e9SAndroid Build Coastguard Worker# class. Unfortunately, there are 3 areas we currently need extra logic in FX.
717*da0073e9SAndroid Build Coastguard Worker# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`.
718*da0073e9SAndroid Build Coastguard Worker# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
719*da0073e9SAndroid Build Coastguard Worker#    Since we can't access .graph within the FX forward, we need to copy the attribute to the module.
720*da0073e9SAndroid Build Coastguard Worker# 3. We currently can't register the pytree imports with `add_global` - not sure why.
721*da0073e9SAndroid Build Coastguard Workerclass _PyTreeCodeGen(CodeGen):
722*da0073e9SAndroid Build Coastguard Worker    def __init__(self, pytree_info: _PyTreeInfo):
723*da0073e9SAndroid Build Coastguard Worker        super().__init__()
724*da0073e9SAndroid Build Coastguard Worker        self.pytree_info: _PyTreeInfo = pytree_info
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker    def process_inputs(self, *inputs: Any) -> Any:
727*da0073e9SAndroid Build Coastguard Worker        flat_args = pytree.arg_tree_leaves(*inputs)
728*da0073e9SAndroid Build Coastguard Worker        return flat_args
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker    def process_outputs(self, out: Any) -> Any:
731*da0073e9SAndroid Build Coastguard Worker        if self.pytree_info is None or self.pytree_info.out_spec is None:
732*da0073e9SAndroid Build Coastguard Worker            return out
733*da0073e9SAndroid Build Coastguard Worker        if not isinstance(out, (list, tuple)):
734*da0073e9SAndroid Build Coastguard Worker            out = [out]
735*da0073e9SAndroid Build Coastguard Worker        assert self.pytree_info.out_spec is not None
736*da0073e9SAndroid Build Coastguard Worker        return pytree.tree_unflatten(out, self.pytree_info.out_spec)
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker    def gen_fn_def(self, free_vars, maybe_return_annotation):
739*da0073e9SAndroid Build Coastguard Worker        # Given a user function/model:
740*da0073e9SAndroid Build Coastguard Worker        #   myargs = [myargs0, myargs1]
741*da0073e9SAndroid Build Coastguard Worker        #   mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
742*da0073e9SAndroid Build Coastguard Worker        #   def forward(self, mypos, *myargs, mykey=None, **mykwargs):
743*da0073e9SAndroid Build Coastguard Worker        #
744*da0073e9SAndroid Build Coastguard Worker        # The generated code flattens all keywords into positional arguments for `forward()`
745*da0073e9SAndroid Build Coastguard Worker        #   e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
746*da0073e9SAndroid Build Coastguard Worker        #
747*da0073e9SAndroid Build Coastguard Worker        # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately
748*da0073e9SAndroid Build Coastguard Worker        #   e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
749*da0073e9SAndroid Build Coastguard Worker        #                           {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
750*da0073e9SAndroid Build Coastguard Worker        #                          self._in_spec)
751*da0073e9SAndroid Build Coastguard Worker        #
752*da0073e9SAndroid Build Coastguard Worker        # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
753*da0073e9SAndroid Build Coastguard Worker        #   e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec)
754*da0073e9SAndroid Build Coastguard Worker        if self.pytree_info is None:
755*da0073e9SAndroid Build Coastguard Worker            return super().gen_fn_def(free_vars, maybe_return_annotation)
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker        fn_args = self.pytree_info.orig_args
758*da0073e9SAndroid Build Coastguard Worker        has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False
759*da0073e9SAndroid Build Coastguard Worker        if has_orig_self:
760*da0073e9SAndroid Build Coastguard Worker            free_vars.insert(0, 'self')
761*da0073e9SAndroid Build Coastguard Worker        fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)
762*da0073e9SAndroid Build Coastguard Worker
763*da0073e9SAndroid Build Coastguard Worker        if len(free_vars) > 0:  # pytree has placeholders in it
764*da0073e9SAndroid Build Coastguard Worker            # when kwargs is present, in_spec is tuple(args, kwargs)
765*da0073e9SAndroid Build Coastguard Worker            has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \
766*da0073e9SAndroid Build Coastguard Worker                self.pytree_info.in_spec.num_children == 2 and \
767*da0073e9SAndroid Build Coastguard Worker                self.pytree_info.in_spec.children_specs[0].type == tuple and \
768*da0073e9SAndroid Build Coastguard Worker                self.pytree_info.in_spec.children_specs[1].type == dict
769*da0073e9SAndroid Build Coastguard Worker            fn_kwargs = '{}'
770*da0073e9SAndroid Build Coastguard Worker            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
771*da0073e9SAndroid Build Coastguard Worker            if has_args_kwargs_tuple:
772*da0073e9SAndroid Build Coastguard Worker                count_args = self.pytree_info.in_spec.children_specs[0].num_children
773*da0073e9SAndroid Build Coastguard Worker                fn_args = self.pytree_info.orig_args[:count_args]
774*da0073e9SAndroid Build Coastguard Worker                fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip(
775*da0073e9SAndroid Build Coastguard Worker                                  self.pytree_info.in_spec.children_specs[1].context,
776*da0073e9SAndroid Build Coastguard Worker                                  self.pytree_info.orig_args[count_args:])) + '}'
777*da0073e9SAndroid Build Coastguard Worker                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
778*da0073e9SAndroid Build Coastguard Worker
779*da0073e9SAndroid Build Coastguard Worker            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
780*da0073e9SAndroid Build Coastguard Worker            # we need to split it to two lines:
781*da0073e9SAndroid Build Coastguard Worker            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
782*da0073e9SAndroid Build Coastguard Worker            # one for code: `var1, var2, = function_call()`
783*da0073e9SAndroid Build Coastguard Worker            without_annotation = [x.split(":")[0] for x in free_vars]
784*da0073e9SAndroid Build Coastguard Worker            has_annotation = [x + "; " for x in free_vars if ":" in x]
785*da0073e9SAndroid Build Coastguard Worker            if len(has_annotation) > 0:
786*da0073e9SAndroid Build Coastguard Worker                fn_definition += "\n    " + "".join(has_annotation) + "\n"
787*da0073e9SAndroid Build Coastguard Worker            fn_definition += f"""
788*da0073e9SAndroid Build Coastguard Worker    {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
789*da0073e9SAndroid Build Coastguard Worker        return fn_definition
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker    def generate_output(self, output_args):
792*da0073e9SAndroid Build Coastguard Worker        if self.pytree_info and self.pytree_info.out_spec:
793*da0073e9SAndroid Build Coastguard Worker            return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)'
794*da0073e9SAndroid Build Coastguard Worker        else:
795*da0073e9SAndroid Build Coastguard Worker            return super().generate_output(output_args)
796*da0073e9SAndroid Build Coastguard Worker
797*da0073e9SAndroid Build Coastguard Workerclass _FindNodesLookupTable:
798*da0073e9SAndroid Build Coastguard Worker    """
799*da0073e9SAndroid Build Coastguard Worker    Side table for the graph for the purpose of doing fast queries
800*da0073e9SAndroid Build Coastguard Worker    """
801*da0073e9SAndroid Build Coastguard Worker    def __init__(self):
802*da0073e9SAndroid Build Coastguard Worker        self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict)
803*da0073e9SAndroid Build Coastguard Worker
804*da0073e9SAndroid Build Coastguard Worker    def _key(self, node) -> Tuple[str, Optional[Target]]:
805*da0073e9SAndroid Build Coastguard Worker        return (node.op, node.target if node.op == "call_function" else None)
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker    def __contains__(self, node) -> bool:
808*da0073e9SAndroid Build Coastguard Worker        return node in self.table[self._key(node)]
809*da0073e9SAndroid Build Coastguard Worker
810*da0073e9SAndroid Build Coastguard Worker    def insert(self, node: Node) -> None:
811*da0073e9SAndroid Build Coastguard Worker        self.table[self._key(node)][node] = None
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker    def remove(self, node: Node) -> None:
814*da0073e9SAndroid Build Coastguard Worker        self.table[self._key(node)].pop(node)
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker    def find_nodes(self, *, op: str, target: Optional['Target'] = None):
817*da0073e9SAndroid Build Coastguard Worker        if op == "call_function":
818*da0073e9SAndroid Build Coastguard Worker            assert target is not None
819*da0073e9SAndroid Build Coastguard Worker            return dict(self.table[(op, target)]).keys()
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker        if target is None:
822*da0073e9SAndroid Build Coastguard Worker            return dict(self.table[(op, None)]).keys()
823*da0073e9SAndroid Build Coastguard Worker
824*da0073e9SAndroid Build Coastguard Worker        # op is call_method, get_attr, call_module
825*da0073e9SAndroid Build Coastguard Worker        return [node for node in self.table[(op, None)].keys() if node.target == target]
826*da0073e9SAndroid Build Coastguard Worker
827*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True)
828*da0073e9SAndroid Build Coastguard Workerclass Graph:
829*da0073e9SAndroid Build Coastguard Worker    """
830*da0073e9SAndroid Build Coastguard Worker    ``Graph`` is the main data structure used in the FX Intermediate Representation.
831*da0073e9SAndroid Build Coastguard Worker    It consists of a series of ``Node`` s, each representing callsites (or other
832*da0073e9SAndroid Build Coastguard Worker    syntactic constructs). The list of ``Node`` s, taken together, constitute a
833*da0073e9SAndroid Build Coastguard Worker    valid Python function.
834*da0073e9SAndroid Build Coastguard Worker
835*da0073e9SAndroid Build Coastguard Worker    For example, the following code
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker    .. code-block:: python
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker        import torch
840*da0073e9SAndroid Build Coastguard Worker        import torch.fx
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
843*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
844*da0073e9SAndroid Build Coastguard Worker                super().__init__()
845*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
846*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
849*da0073e9SAndroid Build Coastguard Worker                return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
850*da0073e9SAndroid Build Coastguard Worker
851*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
852*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker    Will produce the following Graph::
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker        print(gm.graph)
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Worker    .. code-block:: text
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        graph(x):
861*da0073e9SAndroid Build Coastguard Worker            %linear_weight : [num_users=1] = self.linear.weight
862*da0073e9SAndroid Build Coastguard Worker            %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
863*da0073e9SAndroid Build Coastguard Worker            %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
864*da0073e9SAndroid Build Coastguard Worker            %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
865*da0073e9SAndroid Build Coastguard Worker            %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
866*da0073e9SAndroid Build Coastguard Worker            %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
867*da0073e9SAndroid Build Coastguard Worker            return topk_1
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker    For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
870*da0073e9SAndroid Build Coastguard Worker    """
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
873*da0073e9SAndroid Build Coastguard Worker    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None,
874*da0073e9SAndroid Build Coastguard Worker                 tracer_extras: Optional[Dict[str, Any]] = None):
875*da0073e9SAndroid Build Coastguard Worker        """
876*da0073e9SAndroid Build Coastguard Worker        Construct an empty Graph.
877*da0073e9SAndroid Build Coastguard Worker        """
878*da0073e9SAndroid Build Coastguard Worker        self._root : Node = Node(self, '', 'root', '', (), {})
879*da0073e9SAndroid Build Coastguard Worker        self._used_names : Dict[str, int] = {}  # base name -> number
880*da0073e9SAndroid Build Coastguard Worker        self._insert = self._root.prepend
881*da0073e9SAndroid Build Coastguard Worker        self._len = 0
882*da0073e9SAndroid Build Coastguard Worker        self._graph_namespace = _Namespace()
883*da0073e9SAndroid Build Coastguard Worker        self._owning_module = owning_module
884*da0073e9SAndroid Build Coastguard Worker        self._tracer_cls = tracer_cls
885*da0073e9SAndroid Build Coastguard Worker        self._tracer_extras = tracer_extras
886*da0073e9SAndroid Build Coastguard Worker        self._codegen = CodeGen()
887*da0073e9SAndroid Build Coastguard Worker        self._co_fields : Dict[str, Any] = {}
888*da0073e9SAndroid Build Coastguard Worker        self._find_nodes_lookup_table = _FindNodesLookupTable()
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker    @property
891*da0073e9SAndroid Build Coastguard Worker    def owning_module(self):
892*da0073e9SAndroid Build Coastguard Worker        return self._owning_module
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker    @owning_module.setter
895*da0073e9SAndroid Build Coastguard Worker    def owning_module(self, mod: Optional["GraphModule"]):
896*da0073e9SAndroid Build Coastguard Worker        self._owning_module = mod
897*da0073e9SAndroid Build Coastguard Worker
898*da0073e9SAndroid Build Coastguard Worker    @property
899*da0073e9SAndroid Build Coastguard Worker    def nodes(self) -> _node_list:
900*da0073e9SAndroid Build Coastguard Worker        """
901*da0073e9SAndroid Build Coastguard Worker        Get the list of Nodes that constitute this Graph.
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker        Note that this ``Node`` list representation is a doubly-linked list. Mutations
904*da0073e9SAndroid Build Coastguard Worker        during iteration (e.g. delete a Node, add a Node) are safe.
905*da0073e9SAndroid Build Coastguard Worker
906*da0073e9SAndroid Build Coastguard Worker        Returns:
907*da0073e9SAndroid Build Coastguard Worker
908*da0073e9SAndroid Build Coastguard Worker            A doubly-linked list of Nodes. Note that ``reversed`` can be called on
909*da0073e9SAndroid Build Coastguard Worker            this list to switch iteration order.
910*da0073e9SAndroid Build Coastguard Worker        """
911*da0073e9SAndroid Build Coastguard Worker        return _node_list(self)
912*da0073e9SAndroid Build Coastguard Worker
913*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=False)
914*da0073e9SAndroid Build Coastguard Worker    def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True):
915*da0073e9SAndroid Build Coastguard Worker        """
916*da0073e9SAndroid Build Coastguard Worker        Allows for fast query of nodes
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker        Args:
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker            op (str): the name of the operation
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker            target (Optional[Target]): the target of the node. For call_function,
923*da0073e9SAndroid Build Coastguard Worker                the target is required. For other ops, the target is optional.
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker            sort (bool): whether to return nodes in the order they appear on
926*da0073e9SAndroid Build Coastguard Worker                         on the graph.
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker        Returns:
929*da0073e9SAndroid Build Coastguard Worker
930*da0073e9SAndroid Build Coastguard Worker            Iteratable of nodes with the requested op and target.
931*da0073e9SAndroid Build Coastguard Worker        """
932*da0073e9SAndroid Build Coastguard Worker        node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target)
933*da0073e9SAndroid Build Coastguard Worker        if sort:
934*da0073e9SAndroid Build Coastguard Worker            return sorted(node_list)
935*da0073e9SAndroid Build Coastguard Worker        return node_list
936*da0073e9SAndroid Build Coastguard Worker
937*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
938*da0073e9SAndroid Build Coastguard Worker    def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
939*da0073e9SAndroid Build Coastguard Worker        """
940*da0073e9SAndroid Build Coastguard Worker        Copy all nodes from a given graph into ``self``.
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Worker        Args:
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker            g (Graph): The source graph from which to copy Nodes.
945*da0073e9SAndroid Build Coastguard Worker
946*da0073e9SAndroid Build Coastguard Worker            val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping
947*da0073e9SAndroid Build Coastguard Worker                from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed
948*da0073e9SAndroid Build Coastguard Worker                in with values in it already to override copying of certain values.
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker        Returns:
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker            The value in ``self`` that is now equivalent to the output value in ``g``,
953*da0073e9SAndroid Build Coastguard Worker            if ``g`` had an ``output`` node. ``None`` otherwise.
954*da0073e9SAndroid Build Coastguard Worker        """
955*da0073e9SAndroid Build Coastguard Worker        for node in g.nodes:
956*da0073e9SAndroid Build Coastguard Worker            if node in val_map:
957*da0073e9SAndroid Build Coastguard Worker                continue
958*da0073e9SAndroid Build Coastguard Worker            if node.op == 'output':
959*da0073e9SAndroid Build Coastguard Worker                rv = map_arg(node.args[0], lambda n: val_map[n])
960*da0073e9SAndroid Build Coastguard Worker                return rv if not return_output_node else (rv, node)
961*da0073e9SAndroid Build Coastguard Worker            val_map[node] = self.node_copy(node, lambda n : val_map[n])
962*da0073e9SAndroid Build Coastguard Worker        return None
963*da0073e9SAndroid Build Coastguard Worker
964*da0073e9SAndroid Build Coastguard Worker    def __deepcopy__(self, memo=None) -> 'Graph':
965*da0073e9SAndroid Build Coastguard Worker        """
966*da0073e9SAndroid Build Coastguard Worker        Explicitly implement __deepcopy__ to prevent excessive recursion depth
967*da0073e9SAndroid Build Coastguard Worker        from the default implementation. This uses graph_copy to copy the nodes
968*da0073e9SAndroid Build Coastguard Worker        in an iterative way, rather than recursive. It also populates the
969*da0073e9SAndroid Build Coastguard Worker        memoization table to prevent unnecessary copies (e.g. references to
970*da0073e9SAndroid Build Coastguard Worker        nodes or other parts of the Graph from a custom GraphModule implementation.
971*da0073e9SAndroid Build Coastguard Worker        """
972*da0073e9SAndroid Build Coastguard Worker        memo = memo if memo else {}
973*da0073e9SAndroid Build Coastguard Worker        g = Graph(tracer_cls=self._tracer_cls)
974*da0073e9SAndroid Build Coastguard Worker        output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
975*da0073e9SAndroid Build Coastguard Worker        g._codegen = copy.deepcopy(self._codegen)
976*da0073e9SAndroid Build Coastguard Worker        assert isinstance(output_vals, tuple)
977*da0073e9SAndroid Build Coastguard Worker        output_val, old_output_node = output_vals
978*da0073e9SAndroid Build Coastguard Worker        new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None))
979*da0073e9SAndroid Build Coastguard Worker        new_output_node.meta = copy.copy(old_output_node.meta)
980*da0073e9SAndroid Build Coastguard Worker        return g
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
983*da0073e9SAndroid Build Coastguard Worker    def create_node(self, op: str, target: 'Target',
984*da0073e9SAndroid Build Coastguard Worker                    args: Optional[Tuple['Argument', ...]] = None,
985*da0073e9SAndroid Build Coastguard Worker                    kwargs: Optional[Dict[str, 'Argument']] = None,
986*da0073e9SAndroid Build Coastguard Worker                    name: Optional[str] = None,
987*da0073e9SAndroid Build Coastguard Worker                    type_expr: Optional[Any] = None) -> Node:
988*da0073e9SAndroid Build Coastguard Worker        """
989*da0073e9SAndroid Build Coastguard Worker        Create a ``Node`` and add it to the ``Graph`` at the current insert-point.
990*da0073e9SAndroid Build Coastguard Worker        Note that the current insert-point can be set via :meth:`Graph.inserting_before`
991*da0073e9SAndroid Build Coastguard Worker        and :meth:`Graph.inserting_after`.
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Worker        Args:
994*da0073e9SAndroid Build Coastguard Worker            op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
995*da0073e9SAndroid Build Coastguard Worker                'call_module', 'placeholder', or 'output'. The semantics of these opcodes are
996*da0073e9SAndroid Build Coastguard Worker                described in the ``Graph`` docstring.
997*da0073e9SAndroid Build Coastguard Worker
998*da0073e9SAndroid Build Coastguard Worker            args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node.
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker            kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker            name (Optional[str]): an optional string name for the ``Node``.
1003*da0073e9SAndroid Build Coastguard Worker                This will influence the name of the value assigned to in the
1004*da0073e9SAndroid Build Coastguard Worker                Python generated code.
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker            type_expr (Optional[Any]): an optional type annotation representing the
1007*da0073e9SAndroid Build Coastguard Worker                Python type the output of this node will have.
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker        Returns:
1010*da0073e9SAndroid Build Coastguard Worker
1011*da0073e9SAndroid Build Coastguard Worker            The newly-created and inserted node.
1012*da0073e9SAndroid Build Coastguard Worker        """
1013*da0073e9SAndroid Build Coastguard Worker        assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
1014*da0073e9SAndroid Build Coastguard Worker        args = () if args is None else args
1015*da0073e9SAndroid Build Coastguard Worker        kwargs = {} if kwargs is None else kwargs
1016*da0073e9SAndroid Build Coastguard Worker        assert isinstance(args, tuple), "args must be a tuple"
1017*da0073e9SAndroid Build Coastguard Worker        assert isinstance(kwargs, dict), "kwargs must be a dict"
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker        candidate = name if name is not None else self._target_to_str(target)
1020*da0073e9SAndroid Build Coastguard Worker        name = self._graph_namespace.create_name(candidate, None)
1021*da0073e9SAndroid Build Coastguard Worker        n = Node(self, name, op, target, args, kwargs, type_expr)
1022*da0073e9SAndroid Build Coastguard Worker
1023*da0073e9SAndroid Build Coastguard Worker        if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None:
1024*da0073e9SAndroid Build Coastguard Worker            for f in self.owning_module._create_node_hooks:
1025*da0073e9SAndroid Build Coastguard Worker                f(n)
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker        self._graph_namespace.associate_name_with_obj(name, n)
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        self._insert(n)
1030*da0073e9SAndroid Build Coastguard Worker        self._find_nodes_lookup_table.insert(n)
1031*da0073e9SAndroid Build Coastguard Worker        self._len += 1
1032*da0073e9SAndroid Build Coastguard Worker        return n
1033*da0073e9SAndroid Build Coastguard Worker
1034*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=False)
1035*da0073e9SAndroid Build Coastguard Worker    def process_inputs(self, *args):
1036*da0073e9SAndroid Build Coastguard Worker        """
1037*da0073e9SAndroid Build Coastguard Worker        Processes args so that they can be passed to the FX graph.
1038*da0073e9SAndroid Build Coastguard Worker        """
1039*da0073e9SAndroid Build Coastguard Worker        return self._codegen.process_inputs(*args)
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=False)
1042*da0073e9SAndroid Build Coastguard Worker    def process_outputs(self, out):
1043*da0073e9SAndroid Build Coastguard Worker        return self._codegen.process_outputs(out)
1044*da0073e9SAndroid Build Coastguard Worker
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1047*da0073e9SAndroid Build Coastguard Worker    def erase_node(self, to_erase : Node) -> None:
1048*da0073e9SAndroid Build Coastguard Worker        """
1049*da0073e9SAndroid Build Coastguard Worker        Erases a ``Node`` from the ``Graph``. Throws an exception if
1050*da0073e9SAndroid Build Coastguard Worker        there are still users of that node in the ``Graph``.
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker        Args:
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker            to_erase (Node): The ``Node`` to erase from the ``Graph``.
1055*da0073e9SAndroid Build Coastguard Worker        """
1056*da0073e9SAndroid Build Coastguard Worker        if len(to_erase.users) > 0:
1057*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
1058*da0073e9SAndroid Build Coastguard Worker                               f'users in the graph: {to_erase.users}!')
1059*da0073e9SAndroid Build Coastguard Worker        if to_erase.graph != self:
1060*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!")
1061*da0073e9SAndroid Build Coastguard Worker        if to_erase._erased:
1062*da0073e9SAndroid Build Coastguard Worker            warnings.warn(f"erase_node({to_erase}) on an already erased node")
1063*da0073e9SAndroid Build Coastguard Worker            return
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker        if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None:
1066*da0073e9SAndroid Build Coastguard Worker            for f in self.owning_module._erase_node_hooks:
1067*da0073e9SAndroid Build Coastguard Worker                f(to_erase)
1068*da0073e9SAndroid Build Coastguard Worker
1069*da0073e9SAndroid Build Coastguard Worker        self._find_nodes_lookup_table.remove(to_erase)
1070*da0073e9SAndroid Build Coastguard Worker        to_erase._remove_from_list()
1071*da0073e9SAndroid Build Coastguard Worker        to_erase._erased = True  # iterators may retain handles to erased nodes
1072*da0073e9SAndroid Build Coastguard Worker        self._len -= 1
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker        # Null out this Node's argument nodes so that the Nodes referred to
1075*da0073e9SAndroid Build Coastguard Worker        # can update their ``users`` accordingly
1076*da0073e9SAndroid Build Coastguard Worker        new_args = map_arg(to_erase.args, lambda n: None)
1077*da0073e9SAndroid Build Coastguard Worker        assert isinstance(new_args, tuple)
1078*da0073e9SAndroid Build Coastguard Worker        to_erase.args = new_args
1079*da0073e9SAndroid Build Coastguard Worker        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
1080*da0073e9SAndroid Build Coastguard Worker        assert isinstance(new_kwargs, dict)
1081*da0073e9SAndroid Build Coastguard Worker        to_erase.kwargs = new_kwargs
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1084*da0073e9SAndroid Build Coastguard Worker    def inserting_before(self, n: Optional[Node] = None):
1085*da0073e9SAndroid Build Coastguard Worker        """Set the point at which create_node and companion methods will insert into the graph.
1086*da0073e9SAndroid Build Coastguard Worker        When used within a 'with' statement, this will temporary set the insert point and
1087*da0073e9SAndroid Build Coastguard Worker        then restore it when the with statement exits::
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker            with g.inserting_before(n):
1090*da0073e9SAndroid Build Coastguard Worker                ... # inserting before node n
1091*da0073e9SAndroid Build Coastguard Worker            ... # insert point restored to what it was previously
1092*da0073e9SAndroid Build Coastguard Worker            g.inserting_before(n) #  set the insert point permanently
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker        Args:
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker            n (Optional[Node]): The node before which to insert. If None this will insert before
1097*da0073e9SAndroid Build Coastguard Worker                the beginning of the entire graph.
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker        Returns:
1100*da0073e9SAndroid Build Coastguard Worker            A resource manager that will restore the insert point on ``__exit__``.
1101*da0073e9SAndroid Build Coastguard Worker        """
1102*da0073e9SAndroid Build Coastguard Worker        if n is None:
1103*da0073e9SAndroid Build Coastguard Worker            return self.inserting_after(self._root)
1104*da0073e9SAndroid Build Coastguard Worker        assert n.graph == self, "Node to insert before is not in graph."
1105*da0073e9SAndroid Build Coastguard Worker        return _InsertPoint(self, n.prepend)
1106*da0073e9SAndroid Build Coastguard Worker
1107*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1108*da0073e9SAndroid Build Coastguard Worker    def inserting_after(self, n: Optional[Node] = None):
1109*da0073e9SAndroid Build Coastguard Worker        """Set the point at which create_node and companion methods will insert into the graph.
1110*da0073e9SAndroid Build Coastguard Worker        When used within a 'with' statement, this will temporary set the insert point and
1111*da0073e9SAndroid Build Coastguard Worker        then restore it when the with statement exits::
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker            with g.inserting_after(n):
1114*da0073e9SAndroid Build Coastguard Worker                ... # inserting after node n
1115*da0073e9SAndroid Build Coastguard Worker            ... # insert point restored to what it was previously
1116*da0073e9SAndroid Build Coastguard Worker            g.inserting_after(n) #  set the insert point permanently
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker        Args:
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker            n (Optional[Node]): The node before which to insert. If None this will insert after
1121*da0073e9SAndroid Build Coastguard Worker                the beginning of the entire graph.
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker        Returns:
1124*da0073e9SAndroid Build Coastguard Worker            A resource manager that will restore the insert point on ``__exit__``.
1125*da0073e9SAndroid Build Coastguard Worker        """
1126*da0073e9SAndroid Build Coastguard Worker        if n is None:
1127*da0073e9SAndroid Build Coastguard Worker            return self.inserting_before(self._root)
1128*da0073e9SAndroid Build Coastguard Worker        assert n.graph == self, "Node to insert after is not in graph."
1129*da0073e9SAndroid Build Coastguard Worker        return _InsertPoint(self, n.append)
1130*da0073e9SAndroid Build Coastguard Worker
1131*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1132*da0073e9SAndroid Build Coastguard Worker    def placeholder(self, name: str, type_expr: Optional[Any] = None,
1133*da0073e9SAndroid Build Coastguard Worker                    default_value : Any = inspect.Signature.empty) -> Node:
1134*da0073e9SAndroid Build Coastguard Worker        """
1135*da0073e9SAndroid Build Coastguard Worker        Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
1136*da0073e9SAndroid Build Coastguard Worker        a function input.
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker        Args:
1139*da0073e9SAndroid Build Coastguard Worker
1140*da0073e9SAndroid Build Coastguard Worker            name (str): A name for the input value. This corresponds to the name
1141*da0073e9SAndroid Build Coastguard Worker                of the positional argument to the function this ``Graph`` represents.
1142*da0073e9SAndroid Build Coastguard Worker
1143*da0073e9SAndroid Build Coastguard Worker            type_expr (Optional[Any]): an optional type annotation representing the
1144*da0073e9SAndroid Build Coastguard Worker                Python type the output of this node will have. This is needed in some
1145*da0073e9SAndroid Build Coastguard Worker                cases for proper code generation (e.g. when the function is used
1146*da0073e9SAndroid Build Coastguard Worker                subsequently in TorchScript compilation).
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker            default_value (Any): The default value this function argument should take
1149*da0073e9SAndroid Build Coastguard Worker                on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty`
1150*da0073e9SAndroid Build Coastguard Worker                should be passed as this argument to specify that the parameter does _not_
1151*da0073e9SAndroid Build Coastguard Worker                have a default value.
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker        .. note::
1154*da0073e9SAndroid Build Coastguard Worker            The same insertion point and type expression rules apply for this method
1155*da0073e9SAndroid Build Coastguard Worker            as ``Graph.create_node``.
1156*da0073e9SAndroid Build Coastguard Worker        """
1157*da0073e9SAndroid Build Coastguard Worker        args = () if default_value is inspect.Signature.empty else (default_value,)
1158*da0073e9SAndroid Build Coastguard Worker        return self.create_node('placeholder', name, args=args, type_expr=type_expr)
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1161*da0073e9SAndroid Build Coastguard Worker    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
1162*da0073e9SAndroid Build Coastguard Worker        """
1163*da0073e9SAndroid Build Coastguard Worker        Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
1164*da0073e9SAndroid Build Coastguard Worker        fetch of an attribute from the ``Module`` hierarchy.
1165*da0073e9SAndroid Build Coastguard Worker
1166*da0073e9SAndroid Build Coastguard Worker        Args:
1167*da0073e9SAndroid Build Coastguard Worker
1168*da0073e9SAndroid Build Coastguard Worker            qualified_name (str): the fully-qualified name of the attribute to be retrieved.
1169*da0073e9SAndroid Build Coastguard Worker                For example, if the traced Module has a submodule named ``foo``, which has a
1170*da0073e9SAndroid Build Coastguard Worker                submodule named ``bar``, which has an attribute named ``baz``, the qualified
1171*da0073e9SAndroid Build Coastguard Worker                name ``foo.bar.baz`` should be passed as ``qualified_name``.
1172*da0073e9SAndroid Build Coastguard Worker
1173*da0073e9SAndroid Build Coastguard Worker            type_expr (Optional[Any]): an optional type annotation representing the
1174*da0073e9SAndroid Build Coastguard Worker                Python type the output of this node will have.
1175*da0073e9SAndroid Build Coastguard Worker
1176*da0073e9SAndroid Build Coastguard Worker
1177*da0073e9SAndroid Build Coastguard Worker        Returns:
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker            The newly-created and inserted ``get_attr`` node.
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker        .. note::
1182*da0073e9SAndroid Build Coastguard Worker            The same insertion point and type expression rules apply for this method
1183*da0073e9SAndroid Build Coastguard Worker            as ``Graph.create_node``.
1184*da0073e9SAndroid Build Coastguard Worker        """
1185*da0073e9SAndroid Build Coastguard Worker        def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool:
1186*da0073e9SAndroid Build Coastguard Worker            module_path, _, name = qualified_name.rpartition(".")
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker            try:
1189*da0073e9SAndroid Build Coastguard Worker                submod: torch.nn.Module = mod.get_submodule(module_path)
1190*da0073e9SAndroid Build Coastguard Worker            except AttributeError:
1191*da0073e9SAndroid Build Coastguard Worker                warnings.warn(f"Failed to fetch module {module_path}!")
1192*da0073e9SAndroid Build Coastguard Worker                return False
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker            if not hasattr(submod, name):
1195*da0073e9SAndroid Build Coastguard Worker                return False
1196*da0073e9SAndroid Build Coastguard Worker
1197*da0073e9SAndroid Build Coastguard Worker            res = getattr(submod, name)
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker            if (not isinstance(res, torch.nn.Module)
1200*da0073e9SAndroid Build Coastguard Worker                    and not isinstance(res, torch.nn.Parameter)
1201*da0073e9SAndroid Build Coastguard Worker                    and name not in submod._buffers):
1202*da0073e9SAndroid Build Coastguard Worker                return False
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker            return True
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker        if (self.owning_module and
1207*da0073e9SAndroid Build Coastguard Worker                not _get_attr_reference_exists(self.owning_module, qualified_name)):
1208*da0073e9SAndroid Build Coastguard Worker            warnings.warn("Attempted to insert a get_attr Node with no "
1209*da0073e9SAndroid Build Coastguard Worker                          "underlying reference in the owning "
1210*da0073e9SAndroid Build Coastguard Worker                          "GraphModule! Call "
1211*da0073e9SAndroid Build Coastguard Worker                          "GraphModule.add_submodule to add the "
1212*da0073e9SAndroid Build Coastguard Worker                          "necessary submodule, "
1213*da0073e9SAndroid Build Coastguard Worker                          "GraphModule.add_parameter to add the "
1214*da0073e9SAndroid Build Coastguard Worker                          "necessary Parameter, or "
1215*da0073e9SAndroid Build Coastguard Worker                          "nn.Module.register_buffer to add the "
1216*da0073e9SAndroid Build Coastguard Worker                          "necessary buffer", stacklevel=2)
1217*da0073e9SAndroid Build Coastguard Worker        return self.create_node('get_attr', qualified_name, type_expr=type_expr)
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1220*da0073e9SAndroid Build Coastguard Worker    def call_module(self,
1221*da0073e9SAndroid Build Coastguard Worker                    module_name: str,
1222*da0073e9SAndroid Build Coastguard Worker                    args: Optional[Tuple['Argument', ...]] = None,
1223*da0073e9SAndroid Build Coastguard Worker                    kwargs: Optional[Dict[str, 'Argument']] = None,
1224*da0073e9SAndroid Build Coastguard Worker                    type_expr: Optional[Any] = None) -> Node:
1225*da0073e9SAndroid Build Coastguard Worker        """
1226*da0073e9SAndroid Build Coastguard Worker        Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node
1227*da0073e9SAndroid Build Coastguard Worker        represents a call to the forward() function of a ``Module`` in the ``Module``
1228*da0073e9SAndroid Build Coastguard Worker        hierarchy.
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker        Args:
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker            module_name (str): The qualified name of the ``Module`` in the ``Module``
1233*da0073e9SAndroid Build Coastguard Worker                hierarchy to be called. For example, if the traced ``Module`` has a
1234*da0073e9SAndroid Build Coastguard Worker                submodule named ``foo``, which has a submodule named ``bar``, the
1235*da0073e9SAndroid Build Coastguard Worker                qualified name ``foo.bar`` should be passed as ``module_name`` to
1236*da0073e9SAndroid Build Coastguard Worker                call that module.
1237*da0073e9SAndroid Build Coastguard Worker
1238*da0073e9SAndroid Build Coastguard Worker            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
1239*da0073e9SAndroid Build Coastguard Worker                to the called method. Note that this should *not* include a ``self`` argument.
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Worker            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
1242*da0073e9SAndroid Build Coastguard Worker                to the called method
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker            type_expr (Optional[Any]): an optional type annotation representing the
1245*da0073e9SAndroid Build Coastguard Worker                Python type the output of this node will have.
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker        Returns:
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker            The newly-created and inserted ``call_module`` node.
1250*da0073e9SAndroid Build Coastguard Worker
1251*da0073e9SAndroid Build Coastguard Worker        .. note::
1252*da0073e9SAndroid Build Coastguard Worker            The same insertion point and type expression rules apply for this method
1253*da0073e9SAndroid Build Coastguard Worker            as :meth:`Graph.create_node`.
1254*da0073e9SAndroid Build Coastguard Worker        """
1255*da0073e9SAndroid Build Coastguard Worker        if (self.owning_module and
1256*da0073e9SAndroid Build Coastguard Worker                self.owning_module.get_submodule(module_name) is None):
1257*da0073e9SAndroid Build Coastguard Worker            warnings.warn("Attempted to insert a call_module Node with "
1258*da0073e9SAndroid Build Coastguard Worker                          "no underlying reference in the owning "
1259*da0073e9SAndroid Build Coastguard Worker                          "GraphModule! Call "
1260*da0073e9SAndroid Build Coastguard Worker                          "GraphModule.add_submodule to add the "
1261*da0073e9SAndroid Build Coastguard Worker                          "necessary submodule")
1262*da0073e9SAndroid Build Coastguard Worker        return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1265*da0073e9SAndroid Build Coastguard Worker    def call_method(self,
1266*da0073e9SAndroid Build Coastguard Worker                    method_name: str,
1267*da0073e9SAndroid Build Coastguard Worker                    args: Optional[Tuple['Argument', ...]] = None,
1268*da0073e9SAndroid Build Coastguard Worker                    kwargs: Optional[Dict[str, 'Argument']] = None,
1269*da0073e9SAndroid Build Coastguard Worker                    type_expr: Optional[Any] = None) -> Node:
1270*da0073e9SAndroid Build Coastguard Worker        """
1271*da0073e9SAndroid Build Coastguard Worker        Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node
1272*da0073e9SAndroid Build Coastguard Worker        represents a call to a given method on the 0th element of ``args``.
1273*da0073e9SAndroid Build Coastguard Worker
1274*da0073e9SAndroid Build Coastguard Worker        Args:
1275*da0073e9SAndroid Build Coastguard Worker
1276*da0073e9SAndroid Build Coastguard Worker            method_name (str): The name of the method to apply to the self argument.
1277*da0073e9SAndroid Build Coastguard Worker                For example, if args[0] is a ``Node`` representing a ``Tensor``,
1278*da0073e9SAndroid Build Coastguard Worker                then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``.
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
1281*da0073e9SAndroid Build Coastguard Worker                to the called method. Note that this *should* include a ``self`` argument.
1282*da0073e9SAndroid Build Coastguard Worker
1283*da0073e9SAndroid Build Coastguard Worker            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
1284*da0073e9SAndroid Build Coastguard Worker                to the called method
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker            type_expr (Optional[Any]): an optional type annotation representing the
1287*da0073e9SAndroid Build Coastguard Worker                Python type the output of this node will have.
1288*da0073e9SAndroid Build Coastguard Worker
1289*da0073e9SAndroid Build Coastguard Worker        Returns:
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker            The newly created and inserted ``call_method`` node.
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker        .. note::
1294*da0073e9SAndroid Build Coastguard Worker            The same insertion point and type expression rules apply for this method
1295*da0073e9SAndroid Build Coastguard Worker            as :meth:`Graph.create_node`.
1296*da0073e9SAndroid Build Coastguard Worker        """
1297*da0073e9SAndroid Build Coastguard Worker        return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1300*da0073e9SAndroid Build Coastguard Worker    def call_function(self,
1301*da0073e9SAndroid Build Coastguard Worker                      the_function: Callable[..., Any],
1302*da0073e9SAndroid Build Coastguard Worker                      args: Optional[Tuple['Argument', ...]] = None,
1303*da0073e9SAndroid Build Coastguard Worker                      kwargs: Optional[Dict[str, 'Argument']] = None,
1304*da0073e9SAndroid Build Coastguard Worker                      type_expr: Optional[Any] = None) -> Node:
1305*da0073e9SAndroid Build Coastguard Worker        """
1306*da0073e9SAndroid Build Coastguard Worker        Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node
1307*da0073e9SAndroid Build Coastguard Worker        represents a call to a Python callable, specified by ``the_function``.
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker        Args:
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker            the_function (Callable[..., Any]): The function to be called. Can be any PyTorch
1312*da0073e9SAndroid Build Coastguard Worker                operator, Python function, or member of the ``builtins`` or ``operator``
1313*da0073e9SAndroid Build Coastguard Worker                namespaces.
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
1316*da0073e9SAndroid Build Coastguard Worker                to the called function.
1317*da0073e9SAndroid Build Coastguard Worker
1318*da0073e9SAndroid Build Coastguard Worker            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
1319*da0073e9SAndroid Build Coastguard Worker                to the called function
1320*da0073e9SAndroid Build Coastguard Worker
1321*da0073e9SAndroid Build Coastguard Worker            type_expr (Optional[Any]): an optional type annotation representing the
1322*da0073e9SAndroid Build Coastguard Worker                Python type the output of this node will have.
1323*da0073e9SAndroid Build Coastguard Worker
1324*da0073e9SAndroid Build Coastguard Worker        Returns:
1325*da0073e9SAndroid Build Coastguard Worker
1326*da0073e9SAndroid Build Coastguard Worker            The newly created and inserted ``call_function`` node.
1327*da0073e9SAndroid Build Coastguard Worker
1328*da0073e9SAndroid Build Coastguard Worker        .. note::
1329*da0073e9SAndroid Build Coastguard Worker            The same insertion point and type expression rules apply for this method
1330*da0073e9SAndroid Build Coastguard Worker            as :meth:`Graph.create_node`.
1331*da0073e9SAndroid Build Coastguard Worker        """
1332*da0073e9SAndroid Build Coastguard Worker        return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1335*da0073e9SAndroid Build Coastguard Worker    def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
1336*da0073e9SAndroid Build Coastguard Worker        """
1337*da0073e9SAndroid Build Coastguard Worker        Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
1338*da0073e9SAndroid Build Coastguard Worker        the graph of node to the graph of self. Example::
1339*da0073e9SAndroid Build Coastguard Worker
1340*da0073e9SAndroid Build Coastguard Worker            # Copying all the nodes in `g` into `new_graph`
1341*da0073e9SAndroid Build Coastguard Worker            g : torch.fx.Graph = ...
1342*da0073e9SAndroid Build Coastguard Worker            new_graph = torch.fx.graph()
1343*da0073e9SAndroid Build Coastguard Worker            value_remap = {}
1344*da0073e9SAndroid Build Coastguard Worker            for node in g.nodes:
1345*da0073e9SAndroid Build Coastguard Worker                value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
1346*da0073e9SAndroid Build Coastguard Worker
1347*da0073e9SAndroid Build Coastguard Worker        Args:
1348*da0073e9SAndroid Build Coastguard Worker
1349*da0073e9SAndroid Build Coastguard Worker            node (Node): The node to copy into ``self``.
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker            arg_transform (Callable[[Node], Argument]): A function that transforms
1352*da0073e9SAndroid Build Coastguard Worker                ``Node`` arguments in node's ``args`` and ``kwargs`` into the
1353*da0073e9SAndroid Build Coastguard Worker                equivalent argument in ``self``. In the simplest case, this should
1354*da0073e9SAndroid Build Coastguard Worker                retrieve a value out of a table mapping Nodes in the original
1355*da0073e9SAndroid Build Coastguard Worker                graph to ``self``.
1356*da0073e9SAndroid Build Coastguard Worker        """
1357*da0073e9SAndroid Build Coastguard Worker        args = map_arg(node.args, arg_transform)
1358*da0073e9SAndroid Build Coastguard Worker        kwargs = map_arg(node.kwargs, arg_transform)
1359*da0073e9SAndroid Build Coastguard Worker        assert isinstance(args, tuple)
1360*da0073e9SAndroid Build Coastguard Worker        assert isinstance(kwargs, dict)
1361*da0073e9SAndroid Build Coastguard Worker        result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type)
1362*da0073e9SAndroid Build Coastguard Worker        result_node.meta = copy.copy(node.meta)
1363*da0073e9SAndroid Build Coastguard Worker        return result_node
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1366*da0073e9SAndroid Build Coastguard Worker    def output(self, result: 'Argument', type_expr: Optional[Any] = None):
1367*da0073e9SAndroid Build Coastguard Worker        """
1368*da0073e9SAndroid Build Coastguard Worker        Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents
1369*da0073e9SAndroid Build Coastguard Worker        a ``return`` statement in Python code. ``result`` is the value that should
1370*da0073e9SAndroid Build Coastguard Worker        be returned.
1371*da0073e9SAndroid Build Coastguard Worker
1372*da0073e9SAndroid Build Coastguard Worker        Args:
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker            result (Argument): The value to be returned.
1375*da0073e9SAndroid Build Coastguard Worker
1376*da0073e9SAndroid Build Coastguard Worker            type_expr (Optional[Any]): an optional type annotation representing the
1377*da0073e9SAndroid Build Coastguard Worker                Python type the output of this node will have.
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker        .. note::
1380*da0073e9SAndroid Build Coastguard Worker
1381*da0073e9SAndroid Build Coastguard Worker            The same insertion point and type expression rules apply for this method
1382*da0073e9SAndroid Build Coastguard Worker            as ``Graph.create_node``.
1383*da0073e9SAndroid Build Coastguard Worker        """
1384*da0073e9SAndroid Build Coastguard Worker        return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker    def _target_to_str(self, target : Target) -> str:
1387*da0073e9SAndroid Build Coastguard Worker        if callable(target):
1388*da0073e9SAndroid Build Coastguard Worker            op = target.__name__
1389*da0073e9SAndroid Build Coastguard Worker        else:
1390*da0073e9SAndroid Build Coastguard Worker            assert isinstance(target, str)
1391*da0073e9SAndroid Build Coastguard Worker            op = target
1392*da0073e9SAndroid Build Coastguard Worker            if _is_magic(op):
1393*da0073e9SAndroid Build Coastguard Worker                op = op[2:-2]
1394*da0073e9SAndroid Build Coastguard Worker        op = _snake_case(op)
1395*da0073e9SAndroid Build Coastguard Worker        return op
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1398*da0073e9SAndroid Build Coastguard Worker    def python_code(
1399*da0073e9SAndroid Build Coastguard Worker        self, root_module: str, *,
1400*da0073e9SAndroid Build Coastguard Worker        verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False
1401*da0073e9SAndroid Build Coastguard Worker    ) -> PythonCode:
1402*da0073e9SAndroid Build Coastguard Worker        """
1403*da0073e9SAndroid Build Coastguard Worker        Turn this ``Graph`` into valid Python code.
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Worker        Args:
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker            root_module (str): The name of the root module on which to look-up
1408*da0073e9SAndroid Build Coastguard Worker                qualified name targets. This is usually 'self'.
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker        Returns:
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker            A PythonCode object, consisting of two fields:
1413*da0073e9SAndroid Build Coastguard Worker                src: the Python source code representing the object
1414*da0073e9SAndroid Build Coastguard Worker                globals: a dictionary of global names in `src` -> the objects that they reference.
1415*da0073e9SAndroid Build Coastguard Worker        """
1416*da0073e9SAndroid Build Coastguard Worker        # NOTE: [Graph Namespaces]
1417*da0073e9SAndroid Build Coastguard Worker        #
1418*da0073e9SAndroid Build Coastguard Worker        # There are two types of symbols in generated Python source code:
1419*da0073e9SAndroid Build Coastguard Worker        # locals and globals.
1420*da0073e9SAndroid Build Coastguard Worker        #   Locals are locally defined by the output of a node in the Graph.
1421*da0073e9SAndroid Build Coastguard Worker        #   Globals are references to external objects, like functions or types.
1422*da0073e9SAndroid Build Coastguard Worker        #
1423*da0073e9SAndroid Build Coastguard Worker        # When generating Python code, we need to make sure to name things
1424*da0073e9SAndroid Build Coastguard Worker        # appropriately. In particular:
1425*da0073e9SAndroid Build Coastguard Worker        # - All names should be unique, to avoid weird shadowing bugs.
1426*da0073e9SAndroid Build Coastguard Worker        # - These names need to be consistent, e.g. a object should always be
1427*da0073e9SAndroid Build Coastguard Worker        #   referenced by the same name.
1428*da0073e9SAndroid Build Coastguard Worker        #
1429*da0073e9SAndroid Build Coastguard Worker        # To do this, we create a new namespace just for this source. All names
1430*da0073e9SAndroid Build Coastguard Worker        # that get printed must come from this namespace.
1431*da0073e9SAndroid Build Coastguard Worker        #
1432*da0073e9SAndroid Build Coastguard Worker        # Why can't we re-use node.name? Because it was generated within the
1433*da0073e9SAndroid Build Coastguard Worker        # namespace `self._graph_namespace`. In order to provide uniqueness
1434*da0073e9SAndroid Build Coastguard Worker        # over both locals (node.name) *and* globals, we create a completely
1435*da0073e9SAndroid Build Coastguard Worker        # new namespace to put all identifiers in.
1436*da0073e9SAndroid Build Coastguard Worker        namespace = _Namespace()
1437*da0073e9SAndroid Build Coastguard Worker
1438*da0073e9SAndroid Build Coastguard Worker        # Override Node's repr to generate a valid name within our namespace.
1439*da0073e9SAndroid Build Coastguard Worker        # Since repr() is designed to produce a valid Python expression, it
1440*da0073e9SAndroid Build Coastguard Worker        # makes sense to re-use it. This way, it's easy to print something like
1441*da0073e9SAndroid Build Coastguard Worker        # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is
1442*da0073e9SAndroid Build Coastguard Worker        # implemented cooperatively to allow this.
1443*da0073e9SAndroid Build Coastguard Worker        def node_repr(n: Node):
1444*da0073e9SAndroid Build Coastguard Worker            return namespace.create_name(n.name, n)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker        @contextmanager
1447*da0073e9SAndroid Build Coastguard Worker        def override_node_repr(graph: Graph):
1448*da0073e9SAndroid Build Coastguard Worker            orig_repr_fns = {}
1449*da0073e9SAndroid Build Coastguard Worker            for node in graph.nodes:
1450*da0073e9SAndroid Build Coastguard Worker                orig_repr_fns[node] = node._repr_fn
1451*da0073e9SAndroid Build Coastguard Worker                node._repr_fn = node_repr
1452*da0073e9SAndroid Build Coastguard Worker            try:
1453*da0073e9SAndroid Build Coastguard Worker                yield None
1454*da0073e9SAndroid Build Coastguard Worker            finally:
1455*da0073e9SAndroid Build Coastguard Worker                # restore the original repr functions
1456*da0073e9SAndroid Build Coastguard Worker                for node in graph.nodes:
1457*da0073e9SAndroid Build Coastguard Worker                    node._repr_fn = orig_repr_fns[node]
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker        with override_node_repr(self):
1460*da0073e9SAndroid Build Coastguard Worker            return self._python_code(
1461*da0073e9SAndroid Build Coastguard Worker                root_module, namespace,
1462*da0073e9SAndroid Build Coastguard Worker                verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored
1463*da0073e9SAndroid Build Coastguard Worker            )
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker    def _python_code(
1466*da0073e9SAndroid Build Coastguard Worker        self, root_module: str, namespace: _Namespace, *,
1467*da0073e9SAndroid Build Coastguard Worker        verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False,
1468*da0073e9SAndroid Build Coastguard Worker    ) -> PythonCode:
1469*da0073e9SAndroid Build Coastguard Worker        return self._codegen._gen_python_code(
1470*da0073e9SAndroid Build Coastguard Worker            self.nodes, root_module, namespace,
1471*da0073e9SAndroid Build Coastguard Worker            verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored
1472*da0073e9SAndroid Build Coastguard Worker        )
1473*da0073e9SAndroid Build Coastguard Worker
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1476*da0073e9SAndroid Build Coastguard Worker        """
1477*da0073e9SAndroid Build Coastguard Worker        Return a human-readable (not machine-readable) string representation
1478*da0073e9SAndroid Build Coastguard Worker        of this Graph
1479*da0073e9SAndroid Build Coastguard Worker        """
1480*da0073e9SAndroid Build Coastguard Worker        placeholder_names : List[str] = []
1481*da0073e9SAndroid Build Coastguard Worker        # This is a one-element array just so ``format_node`` can modify the closed
1482*da0073e9SAndroid Build Coastguard Worker        # over value
1483*da0073e9SAndroid Build Coastguard Worker        maybe_return_typename : List[str] = ['']
1484*da0073e9SAndroid Build Coastguard Worker
1485*da0073e9SAndroid Build Coastguard Worker        node_strs = [node.format_node(placeholder_names) for node in self.nodes]
1486*da0073e9SAndroid Build Coastguard Worker        param_str = ', '.join(placeholder_names)
1487*da0073e9SAndroid Build Coastguard Worker        s = f'graph({param_str}){maybe_return_typename[0]}:'
1488*da0073e9SAndroid Build Coastguard Worker        for node_str in node_strs:
1489*da0073e9SAndroid Build Coastguard Worker            if node_str:
1490*da0073e9SAndroid Build Coastguard Worker                s += '\n    ' + node_str
1491*da0073e9SAndroid Build Coastguard Worker        return s
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1494*da0073e9SAndroid Build Coastguard Worker    def print_tabular(self):
1495*da0073e9SAndroid Build Coastguard Worker        """
1496*da0073e9SAndroid Build Coastguard Worker        Prints the intermediate representation of the graph in tabular
1497*da0073e9SAndroid Build Coastguard Worker        format. Note that this API requires the ``tabulate`` module to be
1498*da0073e9SAndroid Build Coastguard Worker        installed.
1499*da0073e9SAndroid Build Coastguard Worker        """
1500*da0073e9SAndroid Build Coastguard Worker        try:
1501*da0073e9SAndroid Build Coastguard Worker            from tabulate import tabulate
1502*da0073e9SAndroid Build Coastguard Worker        except ImportError:
1503*da0073e9SAndroid Build Coastguard Worker            print("`print_tabular` relies on the library `tabulate`, "
1504*da0073e9SAndroid Build Coastguard Worker                  "which could not be found on this machine. Run `pip "
1505*da0073e9SAndroid Build Coastguard Worker                  "install tabulate` to install the library.")
1506*da0073e9SAndroid Build Coastguard Worker            raise
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker        node_specs = [[n.op, n.name, n.target, n.args, n.kwargs]
1509*da0073e9SAndroid Build Coastguard Worker                      for n in self.nodes]
1510*da0073e9SAndroid Build Coastguard Worker        print(tabulate(node_specs,
1511*da0073e9SAndroid Build Coastguard Worker              headers=['opcode', 'name', 'target', 'args', 'kwargs']))
1512*da0073e9SAndroid Build Coastguard Worker
    @compatibility(is_backward_compatible=True)
    def lint(self) -> None:
        """
        Runs various checks on this Graph to make sure it is well-formed.

        Every node, in graph order, is checked for the following:

        - its ``op`` is one of ``placeholder``, ``call_method``,
          ``call_module``, ``call_function``, ``get_attr`` or ``output``;
        - it is owned by this graph (``node.graph is self``);
        - it is registered in ``self._find_nodes_lookup_table``;
        - every ``Node`` it uses in ``args``/``kwargs`` is owned by this graph
          and appears earlier in the node list (topological order);
        - its ``name`` was not already used by an earlier node.

        If ``self.owning_module`` is set (truthy), targets are also checked:

        - a ``call_function`` target must be callable; any other target must
          be a ``str``;
        - a ``get_attr`` / ``call_module`` target is resolved one dotted atom
          at a time, starting from the owning module. The following rules
          apply to each atom:

          - a missing atom is an error;
          - for ``call_module``, an atom that is not an ``nn.Module`` is an
            error;
          - for ``get_attr``, an atom that is not an ``nn.Module``,
            ``nn.Parameter`` or registered buffer only produces a warning. At
            most ``MAX_WARNINGS`` (5) such warnings are emitted individually,
            followed by one warning reporting how many were suppressed.

        Raises:
            RuntimeError: unknown opcode, node or argument owned by another
                graph, node missing from the lookup table, argument used
                before its definition, duplicate node name, unresolvable
                target atom, or a ``call_module`` atom that is not an
                ``nn.Module``.
            ValueError: a ``call_function`` target that is not callable, or a
                non-``call_function`` target that is not a ``str``.
        """

        # Check topo order
        def check_arg(arg : Node, n : Optional[Node] = None) -> None:
            # Validate one Node ``arg`` used by node ``n``. ``arg`` must be
            # owned by this graph and must already have been visited by the
            # loop below. ``seen_values`` is assigned further down; this
            # closure only reads it when called from inside that loop.
            context_str = f' of Node \'{n}\' ' if n else ' '
            if arg.graph is not self:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
                                   f'but was used as an argument! If you are copying nodes from another graph, make '
                                   f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}')
            if arg not in seen_values:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
                                   f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')

        seen_names : Set[str] = set()
        seen_values : Set[Node] = set()
        for node in self.nodes:
            if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
                raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
            if node.graph is not self:
                raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
            if node not in self._find_nodes_lookup_table:
                raise RuntimeError(f"Node '{node}' is not added to the side table")
            # map_arg is expected to invoke the callback on each Node nested in
            # args/kwargs. The lambdas run immediately, so capturing the loop
            # variable ``node`` is safe.
            map_arg(node.args, lambda arg: check_arg(arg, node))
            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
            # A node is marked as defined only after its own inputs were
            # checked, so a node that uses itself as an argument is rejected.
            seen_values.add(node)

            if node.name in seen_names:
                raise RuntimeError(f'Node redefined name {node.name}!')
            seen_names.add(node.name)

        # Check targets are legit
        if self.owning_module:
            # num_warnings counts every offending get_attr atom, including the
            # ones whose warning is suppressed.
            num_warnings = 0
            MAX_WARNINGS = 5
            for node in self.nodes:
                if node.op == 'call_function':
                    if not callable(node.target):
                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
                                         'a Callable is expected')
                else:
                    if not isinstance(node.target, str):
                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
                                         'a str is expected')
                if node.op in ['get_attr', 'call_module']:
                    # Walk the dotted target atom by atom. ``m_itr`` is the
                    # object resolved so far, starting at the owning module.
                    target_atoms = node.target.split('.')
                    m_itr = self.owning_module
                    for i, atom in enumerate(target_atoms):
                        # NOTE(review): the ``None`` default means an attribute
                        # that exists but holds ``None`` is reported as
                        # nonexistent below.
                        new_m_itr = getattr(m_itr, atom, None)
                        # Dotted path of the atoms before this one ('' for the
                        # first atom); used only in messages.
                        seen_qualname = '.'.join(target_atoms[:i])
                        if new_m_itr is None:
                            raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
                                               f'{atom} of {seen_qualname}')
                        if (node.op == "call_module"
                                and not isinstance(new_m_itr, torch.nn.Module)):
                            raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
                                               'not reference an nn.Module')
                        elif (node.op == "get_attr"
                              and not isinstance(new_m_itr, torch.nn.Module)
                              and not isinstance(new_m_itr, torch.nn.Parameter)
                              and atom not in m_itr._buffers):
                            if num_warnings < MAX_WARNINGS:
                                # Don't emit this warning too frequently,
                                # for very large graphs this can become very expensive
                                # from a performance perspective.
                                warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
                                              'not reference an nn.Module, nn.Parameter, or buffer, which is '
                                              'what \'get_attr\' Nodes typically target')
                            num_warnings += 1
                            # NOTE(review): m_itr is not advanced on this path.
                            # Any remaining atoms are therefore looked up on the
                            # previous object. Confirm this is intended.
                        else:
                            m_itr = new_m_itr
            # Only the first MAX_WARNINGS offending atoms were reported
            # individually; summarize the rest.
            if num_warnings > MAX_WARNINGS:
                warnings.warn(
                    f'Additional {num_warnings - MAX_WARNINGS} warnings '
                    'suppressed about get_attr references'
                )
1597*da0073e9SAndroid Build Coastguard Worker
1598*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
1599*da0073e9SAndroid Build Coastguard Worker    def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None):
1600*da0073e9SAndroid Build Coastguard Worker        """
1601*da0073e9SAndroid Build Coastguard Worker        Remove all dead code from the graph, based on each node's number of
1602*da0073e9SAndroid Build Coastguard Worker        users, and whether the nodes have any side effects. The graph must be
1603*da0073e9SAndroid Build Coastguard Worker        topologically sorted before calling.
1604*da0073e9SAndroid Build Coastguard Worker
1605*da0073e9SAndroid Build Coastguard Worker        Args:
1606*da0073e9SAndroid Build Coastguard Worker            is_impure_node (Optional[Callable[[Node], bool]]): A function that returns
1607*da0073e9SAndroid Build Coastguard Worker            whether a node is impure. If this is None, then the default behavior is to
1608*da0073e9SAndroid Build Coastguard Worker            use Node.is_impure.
1609*da0073e9SAndroid Build Coastguard Worker
1610*da0073e9SAndroid Build Coastguard Worker        Returns:
1611*da0073e9SAndroid Build Coastguard Worker          bool: Whether the graph was changed as a result of the pass.
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker        Example:
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker        Before dead code is eliminated, `a` from `a = x + 1` below has no users
1616*da0073e9SAndroid Build Coastguard Worker        and thus can be eliminated from the graph without having an effect.
1617*da0073e9SAndroid Build Coastguard Worker
1618*da0073e9SAndroid Build Coastguard Worker        .. code-block:: python
1619*da0073e9SAndroid Build Coastguard Worker
1620*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1621*da0073e9SAndroid Build Coastguard Worker                a = x + 1
1622*da0073e9SAndroid Build Coastguard Worker                return x + self.attr_1
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker        After dead code is eliminated, `a = x + 1` has been removed, and the rest
1625*da0073e9SAndroid Build Coastguard Worker        of `forward` remains.
1626*da0073e9SAndroid Build Coastguard Worker
1627*da0073e9SAndroid Build Coastguard Worker        .. code-block:: python
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1630*da0073e9SAndroid Build Coastguard Worker                return x + self.attr_1
1631*da0073e9SAndroid Build Coastguard Worker
1632*da0073e9SAndroid Build Coastguard Worker        .. warning::
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker            Dead code elimination has some heuristics to avoid removing
1635*da0073e9SAndroid Build Coastguard Worker            side-effectful nodes (see Node.is_impure) but in general coverage
1636*da0073e9SAndroid Build Coastguard Worker            is very bad, so you should assume that this method is not sound
1637*da0073e9SAndroid Build Coastguard Worker            to call unless you know that your FX graph consists entirely
1638*da0073e9SAndroid Build Coastguard Worker            of functional operations or you supply your own custom
1639*da0073e9SAndroid Build Coastguard Worker            function for detecting side-effectful nodes.
1640*da0073e9SAndroid Build Coastguard Worker        """
1641*da0073e9SAndroid Build Coastguard Worker        # Lint the graph first to make sure its topologically sorted, otherwise
1642*da0073e9SAndroid Build Coastguard Worker        # DCE below will not behave as expected.
1643*da0073e9SAndroid Build Coastguard Worker        self.lint()
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Worker        def has_side_effect(node):
1646*da0073e9SAndroid Build Coastguard Worker            if is_impure_node is not None:
1647*da0073e9SAndroid Build Coastguard Worker                return is_impure_node(node)
1648*da0073e9SAndroid Build Coastguard Worker            return node.is_impure()
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        # Reverse iterate so that when we remove a node, any nodes used as an
1651*da0073e9SAndroid Build Coastguard Worker        # input to that node have an updated user count that no longer reflects
1652*da0073e9SAndroid Build Coastguard Worker        # the removed node.
1653*da0073e9SAndroid Build Coastguard Worker        changed = False
1654*da0073e9SAndroid Build Coastguard Worker        for node in reversed(self.nodes):
1655*da0073e9SAndroid Build Coastguard Worker            if not has_side_effect(node) and len(node.users) == 0:
1656*da0073e9SAndroid Build Coastguard Worker                self.erase_node(node)
1657*da0073e9SAndroid Build Coastguard Worker                changed = True
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker        return changed
1660*da0073e9SAndroid Build Coastguard Worker
1661*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=False)
1662*da0073e9SAndroid Build Coastguard Worker    def set_codegen(self, codegen: CodeGen):
1663*da0073e9SAndroid Build Coastguard Worker        self._codegen = codegen
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=False)
1666*da0073e9SAndroid Build Coastguard Worker    def on_generate_code(
1667*da0073e9SAndroid Build Coastguard Worker        self,
1668*da0073e9SAndroid Build Coastguard Worker        make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]
1669*da0073e9SAndroid Build Coastguard Worker    ):
1670*da0073e9SAndroid Build Coastguard Worker        """Register a transformer function when python code is generated
1671*da0073e9SAndroid Build Coastguard Worker
1672*da0073e9SAndroid Build Coastguard Worker        Args:
1673*da0073e9SAndroid Build Coastguard Worker            make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
1674*da0073e9SAndroid Build Coastguard Worker                a function that returns a code transformer to be registered.
1675*da0073e9SAndroid Build Coastguard Worker                This function is called by `on_generate_code` to obtain the
1676*da0073e9SAndroid Build Coastguard Worker                code transformer.
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker                This function is also given as its input the currently
1679*da0073e9SAndroid Build Coastguard Worker                registered code transformer (or None if nothing is registered),
1680*da0073e9SAndroid Build Coastguard Worker                in case it is not desirable to overwrite it. This is useful to
1681*da0073e9SAndroid Build Coastguard Worker                chain code transformers together.
1682*da0073e9SAndroid Build Coastguard Worker
1683*da0073e9SAndroid Build Coastguard Worker        Returns:
1684*da0073e9SAndroid Build Coastguard Worker            a context manager that when used in a `with` statement, to automatically
1685*da0073e9SAndroid Build Coastguard Worker            restore the previously registered code transformer.
1686*da0073e9SAndroid Build Coastguard Worker
1687*da0073e9SAndroid Build Coastguard Worker        Example:
1688*da0073e9SAndroid Build Coastguard Worker
1689*da0073e9SAndroid Build Coastguard Worker        .. code-block:: python
1690*da0073e9SAndroid Build Coastguard Worker
1691*da0073e9SAndroid Build Coastguard Worker
1692*da0073e9SAndroid Build Coastguard Worker            gm: fx.GraphModule = ...
1693*da0073e9SAndroid Build Coastguard Worker
1694*da0073e9SAndroid Build Coastguard Worker            # This is a code transformer we want to register. This code
1695*da0073e9SAndroid Build Coastguard Worker            # transformer prepends a pdb import and trace statement at the very
1696*da0073e9SAndroid Build Coastguard Worker            # beginning of the generated torch.fx code to allow for manual
1697*da0073e9SAndroid Build Coastguard Worker            # debugging with the PDB library.
1698*da0073e9SAndroid Build Coastguard Worker            def insert_pdb(body):
1699*da0073e9SAndroid Build Coastguard Worker                return ["import pdb; pdb.set_trace()\\n", *body]
1700*da0073e9SAndroid Build Coastguard Worker
1701*da0073e9SAndroid Build Coastguard Worker            # Registers `insert_pdb`, and overwrites the current registered
1702*da0073e9SAndroid Build Coastguard Worker            # code transformer (given by `_` to the lambda):
1703*da0073e9SAndroid Build Coastguard Worker            gm.graph.on_generate_code(
1704*da0073e9SAndroid Build Coastguard Worker                lambda _: insert_pdb
1705*da0073e9SAndroid Build Coastguard Worker            )
1706*da0073e9SAndroid Build Coastguard Worker
1707*da0073e9SAndroid Build Coastguard Worker            # Or alternatively, registers a code transformer which first
1708*da0073e9SAndroid Build Coastguard Worker            # runs `body` through existing registered transformer, then
1709*da0073e9SAndroid Build Coastguard Worker            # through `insert_pdb`:
1710*da0073e9SAndroid Build Coastguard Worker            gm.graph.on_generate_code(
1711*da0073e9SAndroid Build Coastguard Worker                lambda current_trans: (
1712*da0073e9SAndroid Build Coastguard Worker                    lambda body: insert_pdb(
1713*da0073e9SAndroid Build Coastguard Worker                        current_trans(body) if current_trans
1714*da0073e9SAndroid Build Coastguard Worker                        else body
1715*da0073e9SAndroid Build Coastguard Worker                    )
1716*da0073e9SAndroid Build Coastguard Worker                )
1717*da0073e9SAndroid Build Coastguard Worker            )
1718*da0073e9SAndroid Build Coastguard Worker
1719*da0073e9SAndroid Build Coastguard Worker            gm.recompile()
1720*da0073e9SAndroid Build Coastguard Worker            gm(*inputs)  # drops into pdb
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker
1723*da0073e9SAndroid Build Coastguard Worker        This function can also be used as a context manager, with the benefit to
1724*da0073e9SAndroid Build Coastguard Worker        automatically restores the previously registered code transformer:
1725*da0073e9SAndroid Build Coastguard Worker
1726*da0073e9SAndroid Build Coastguard Worker        .. code-block:: python
1727*da0073e9SAndroid Build Coastguard Worker
1728*da0073e9SAndroid Build Coastguard Worker            # ... continue from previous example
1729*da0073e9SAndroid Build Coastguard Worker
1730*da0073e9SAndroid Build Coastguard Worker            with gm.graph.on_generate_code(lambda _: insert_pdb):
1731*da0073e9SAndroid Build Coastguard Worker                # do more stuff with `gm`...
1732*da0073e9SAndroid Build Coastguard Worker                gm.recompile()
1733*da0073e9SAndroid Build Coastguard Worker                gm(*inputs)  # drops into pdb
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker            # now previous code transformer is restored (but `gm`'s code with pdb
1736*da0073e9SAndroid Build Coastguard Worker            # remains - that means you can run `gm` with pdb here too, until you
1737*da0073e9SAndroid Build Coastguard Worker            # run next `recompile()`).
1738*da0073e9SAndroid Build Coastguard Worker        """
1739*da0073e9SAndroid Build Coastguard Worker        on_gen_code_old = self._codegen._body_transformer
1740*da0073e9SAndroid Build Coastguard Worker        self._codegen._body_transformer = make_transformer(on_gen_code_old)
1741*da0073e9SAndroid Build Coastguard Worker
1742*da0073e9SAndroid Build Coastguard Worker        @contextlib.contextmanager
1743*da0073e9SAndroid Build Coastguard Worker        def on_generate_code_context_manager():
1744*da0073e9SAndroid Build Coastguard Worker            try:
1745*da0073e9SAndroid Build Coastguard Worker                yield
1746*da0073e9SAndroid Build Coastguard Worker            finally:
1747*da0073e9SAndroid Build Coastguard Worker                self._codegen._body_transformer = on_gen_code_old
1748*da0073e9SAndroid Build Coastguard Worker
1749*da0073e9SAndroid Build Coastguard Worker        return on_generate_code_context_manager()
1750*da0073e9SAndroid Build Coastguard Worker
1751*da0073e9SAndroid Build Coastguard Worker
# Binary-operator method names (e.g. ``add``, ``and_``) mapped to a format
# template that renders the operation as Python source; each ``{}`` is filled
# with one operand. Presumably these are the operators that also have a
# reflected (``__r*__``) form -- confirm at the use sites.
reflectable_magic_methods = dict(
    add='{} + {}',
    sub='{} - {}',
    mul='{} * {}',
    floordiv='{} // {}',
    truediv='{} / {}',
    div='{} / {}',
    mod='{} % {}',
    pow='{} ** {}',
    lshift='{} << {}',
    rshift='{} >> {}',
    and_='{} & {}',
    or_='{} | {}',
    xor='{} ^ {}',
    getitem='{}[{}]',
    matmul='{} @ {}',
)

# Comparison and unary-operator templates, followed by every entry of
# ``reflectable_magic_methods`` (in that order).
magic_methods = {
    'eq': '{} == {}',
    'ne': '{} != {}',
    'lt': '{} < {}',
    'gt': '{} > {}',
    'le': '{} <= {}',
    'ge': '{} >= {}',
    'pos': '+{}',
    'neg': '-{}',
    'invert': '~{}',
    **reflectable_magic_methods,
}
1780*da0073e9SAndroid Build Coastguard Worker
# In-place operator method names mapped to a Python source template; each
# ``{}`` is filled with an operand (``setitem`` takes target, key and value).
inplace_methods = dict(
    iadd='{} += {}',
    iand='{} &= {}',
    ifloordiv='{} //= {}',
    ilshift='{} <<= {}',
    imod='{} %= {}',
    imul='{} *= {}',
    imatmul='{} @= {}',
    ior='{} |= {}',
    ipow='{} **= {}',
    irshift='{} >>= {}',
    isub='{} -= {}',
    itruediv='{} /= {}',
    ixor='{} ^= {}',
    setitem='{}[{}] = {}',
)
1797