1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict 3*da0073e9SAndroid Build Coastguard Workerfrom .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name 4*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree 5*da0073e9SAndroid Build Coastguard Workerfrom . import _pytree as fx_pytree 6*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility 7*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _NodeIter 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport os 10*da0073e9SAndroid Build Coastguard Workerimport contextlib 11*da0073e9SAndroid Build Coastguard Workerfrom typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable 12*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass 13*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager 14*da0073e9SAndroid Build Coastguard Workerimport copy 15*da0073e9SAndroid Build Coastguard Workerimport enum 16*da0073e9SAndroid Build Coastguard Workerimport torch 17*da0073e9SAndroid Build Coastguard Workerimport keyword 18*da0073e9SAndroid Build Coastguard Workerimport re 19*da0073e9SAndroid Build Coastguard Workerimport builtins 20*da0073e9SAndroid Build Coastguard Workerimport math 21*da0073e9SAndroid Build Coastguard Workerimport warnings 22*da0073e9SAndroid Build Coastguard Workerimport inspect 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker__all__ = ["PythonCode", "CodeGen", "Graph"] 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 27*da0073e9SAndroid Build Coastguard Worker from .graph_module import GraphModule # noqa: F401 28*da0073e9SAndroid Build Coastguard Worker from ._symbolic_trace import Tracer # noqa: F401 
29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker# Mapping of builtins to their `typing` equivalent. 32*da0073e9SAndroid Build Coastguard Worker_origin_type_map = { 33*da0073e9SAndroid Build Coastguard Worker list: List, 34*da0073e9SAndroid Build Coastguard Worker dict: Dict, 35*da0073e9SAndroid Build Coastguard Worker set: Set, 36*da0073e9SAndroid Build Coastguard Worker frozenset: FrozenSet, 37*da0073e9SAndroid Build Coastguard Worker tuple: Tuple, 38*da0073e9SAndroid Build Coastguard Worker} 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker# Signature for functions thattransforms the body (`list[str]`) of the 42*da0073e9SAndroid Build Coastguard Worker# generated code 43*da0073e9SAndroid Build Coastguard WorkerTransformCodeFunc = Callable[[List[str]], List[str]] 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Workerclass _CustomBuiltin(NamedTuple): 47*da0073e9SAndroid Build Coastguard Worker """Additional objs that we add to every graph's globals. 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker The repr() for some standard library objects is not valid Python code without 50*da0073e9SAndroid Build Coastguard Worker an import. For common objects of this sort, we bundle them in the globals of 51*da0073e9SAndroid Build Coastguard Worker every FX graph. 52*da0073e9SAndroid Build Coastguard Worker """ 53*da0073e9SAndroid Build Coastguard Worker # How to import this object from the standard library. 54*da0073e9SAndroid Build Coastguard Worker import_str: str 55*da0073e9SAndroid Build Coastguard Worker # The actual object, produced from that import string. 
56*da0073e9SAndroid Build Coastguard Worker obj: Any 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker_custom_builtins: Dict[str, _CustomBuiltin] = {} 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Workerdef _register_custom_builtin(name: str, import_str: str, obj: Any): 62*da0073e9SAndroid Build Coastguard Worker _custom_builtins[name] = _CustomBuiltin(import_str, obj) 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker_register_custom_builtin('inf', 'from math import inf', math.inf) 66*da0073e9SAndroid Build Coastguard Worker_register_custom_builtin('nan', 'from math import nan', math.nan) 67*da0073e9SAndroid Build Coastguard Worker_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) 68*da0073e9SAndroid Build Coastguard Worker_register_custom_builtin('torch', 'import torch', torch) 69*da0073e9SAndroid Build Coastguard Worker_register_custom_builtin('device', 'from torch import device', torch.device) 70*da0073e9SAndroid Build Coastguard Worker_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree) 71*da0073e9SAndroid Build Coastguard Worker_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Workerdef _is_magic(x: str) -> bool: 75*da0073e9SAndroid Build Coastguard Worker return x.startswith('__') and x.endswith('__') 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Workerdef _snake_case(s: str) -> str: 79*da0073e9SAndroid Build Coastguard Worker """ 80*da0073e9SAndroid Build Coastguard Worker Transforms the given string ``s`` to a Python-style variable name 81*da0073e9SAndroid Build 
Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker Examples: 83*da0073e9SAndroid Build Coastguard Worker ``mod.snake_case`` -> ``mod.snake_case`` 84*da0073e9SAndroid Build Coastguard Worker ``mod.pascalCase``-> ``mod.pascal_case`` 85*da0073e9SAndroid Build Coastguard Worker ``mod.ALL_CAPS`` -> ``mod.all_caps`` 86*da0073e9SAndroid Build Coastguard Worker """ 87*da0073e9SAndroid Build Coastguard Worker chars = [] 88*da0073e9SAndroid Build Coastguard Worker prev_lower = False 89*da0073e9SAndroid Build Coastguard Worker for c in s: 90*da0073e9SAndroid Build Coastguard Worker if prev_lower and c.isupper(): 91*da0073e9SAndroid Build Coastguard Worker chars.append('_') 92*da0073e9SAndroid Build Coastguard Worker chars.append(c.lower()) 93*da0073e9SAndroid Build Coastguard Worker prev_lower = c.islower() 94*da0073e9SAndroid Build Coastguard Worker return ''.join(chars) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Workerdef _is_from_torch(obj: Any) -> bool: 98*da0073e9SAndroid Build Coastguard Worker module_name = getattr(obj, '__module__', None) 99*da0073e9SAndroid Build Coastguard Worker if module_name is not None: 100*da0073e9SAndroid Build Coastguard Worker base_module = module_name.partition('.')[0] 101*da0073e9SAndroid Build Coastguard Worker return ( 102*da0073e9SAndroid Build Coastguard Worker base_module == 'torch' and 103*da0073e9SAndroid Build Coastguard Worker not module_name.startswith("torch._dynamo.") and 104*da0073e9SAndroid Build Coastguard Worker not module_name.startswith("torch._inductor.") 105*da0073e9SAndroid Build Coastguard Worker ) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker name = getattr(obj, '__name__', None) 108*da0073e9SAndroid Build Coastguard Worker # exclude torch because torch.torch.torch.torch works. 
idk mang 109*da0073e9SAndroid Build Coastguard Worker if name is not None and name != 'torch': 110*da0073e9SAndroid Build Coastguard Worker for guess in [torch, torch.nn.functional]: 111*da0073e9SAndroid Build Coastguard Worker if getattr(guess, name, None) is obj: 112*da0073e9SAndroid Build Coastguard Worker return True 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker return False 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Workerclass _Namespace: 118*da0073e9SAndroid Build Coastguard Worker """A context for associating names uniquely with objects. 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker The following invariants are enforced: 121*da0073e9SAndroid Build Coastguard Worker - Each object gets a single name. 122*da0073e9SAndroid Build Coastguard Worker - Each name is unique within a given namespace. 123*da0073e9SAndroid Build Coastguard Worker - Names generated do not shadow builtins, unless the object is indeed that builtin. 
124*da0073e9SAndroid Build Coastguard Worker """ 125*da0073e9SAndroid Build Coastguard Worker def __init__(self): 126*da0073e9SAndroid Build Coastguard Worker self._obj_to_name: Dict[Any, str] = {} 127*da0073e9SAndroid Build Coastguard Worker self._unassociated_names = set() 128*da0073e9SAndroid Build Coastguard Worker self._used_names: Set[str] = set() 129*da0073e9SAndroid Build Coastguard Worker self._base_count: Dict[str, int] = defaultdict(int) 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') 132*da0073e9SAndroid Build Coastguard Worker self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker def create_name(self, candidate: str, obj: Optional[Any]) -> str: 135*da0073e9SAndroid Build Coastguard Worker """Create a unique name. 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker Arguments: 138*da0073e9SAndroid Build Coastguard Worker candidate: used as the basis for the unique name, relevant to the user. 139*da0073e9SAndroid Build Coastguard Worker obj: If not None, an object that will be associated with the unique name. 
140*da0073e9SAndroid Build Coastguard Worker """ 141*da0073e9SAndroid Build Coastguard Worker if obj is not None and obj in self._obj_to_name: 142*da0073e9SAndroid Build Coastguard Worker return self._obj_to_name[obj] 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker # delete all characters that are illegal in a Python identifier 145*da0073e9SAndroid Build Coastguard Worker candidate = self._illegal_char_regex.sub('_', candidate) 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker if not candidate: 148*da0073e9SAndroid Build Coastguard Worker candidate = '_unnamed' 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker if candidate[0].isdigit(): 151*da0073e9SAndroid Build Coastguard Worker candidate = f'_{candidate}' 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker match = self._name_suffix_regex.match(candidate) 154*da0073e9SAndroid Build Coastguard Worker if match is None: 155*da0073e9SAndroid Build Coastguard Worker base = candidate 156*da0073e9SAndroid Build Coastguard Worker num = None 157*da0073e9SAndroid Build Coastguard Worker else: 158*da0073e9SAndroid Build Coastguard Worker base, num_str = match.group(1, 2) 159*da0073e9SAndroid Build Coastguard Worker num = int(num_str) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker candidate = base if num is None else f'{base}_{num}' 162*da0073e9SAndroid Build Coastguard Worker if not num: 163*da0073e9SAndroid Build Coastguard Worker num = self._base_count[base] 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker while candidate in self._used_names or self._is_illegal_name(candidate, obj): 166*da0073e9SAndroid Build Coastguard Worker num += 1 167*da0073e9SAndroid Build Coastguard Worker candidate = f'{base}_{num}' 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build 
Coastguard Worker self._used_names.add(candidate) 170*da0073e9SAndroid Build Coastguard Worker self._base_count[base] = num 171*da0073e9SAndroid Build Coastguard Worker if obj is None: 172*da0073e9SAndroid Build Coastguard Worker self._unassociated_names.add(candidate) 173*da0073e9SAndroid Build Coastguard Worker else: 174*da0073e9SAndroid Build Coastguard Worker self._obj_to_name[obj] = candidate 175*da0073e9SAndroid Build Coastguard Worker return candidate 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker def associate_name_with_obj(self, name: str, obj: Any): 178*da0073e9SAndroid Build Coastguard Worker """Associate a unique name with an object. 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker Neither `name` nor `obj` should be associated already. 181*da0073e9SAndroid Build Coastguard Worker """ 182*da0073e9SAndroid Build Coastguard Worker assert obj not in self._obj_to_name 183*da0073e9SAndroid Build Coastguard Worker assert name in self._unassociated_names 184*da0073e9SAndroid Build Coastguard Worker self._obj_to_name[obj] = name 185*da0073e9SAndroid Build Coastguard Worker self._unassociated_names.remove(name) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker def _is_illegal_name(self, name: str, obj: Any) -> bool: 188*da0073e9SAndroid Build Coastguard Worker # 1. keywords are never allowed as names. 189*da0073e9SAndroid Build Coastguard Worker if name in keyword.kwlist: 190*da0073e9SAndroid Build Coastguard Worker return True 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker # 2. Can't shadow a builtin name, unless you *are* that builtin. 193*da0073e9SAndroid Build Coastguard Worker if name in builtins.__dict__: 194*da0073e9SAndroid Build Coastguard Worker return obj is not builtins.__dict__[name] 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker # 3. 
Can't shadow our custom builtins either 197*da0073e9SAndroid Build Coastguard Worker if name in _custom_builtins: 198*da0073e9SAndroid Build Coastguard Worker return obj is not _custom_builtins[name].obj 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker return False 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker def _rename_object(self, obj: Any, name: str): 203*da0073e9SAndroid Build Coastguard Worker assert obj in self._obj_to_name 204*da0073e9SAndroid Build Coastguard Worker self._obj_to_name[obj] = name 205*da0073e9SAndroid Build Coastguard Worker self._used_names.add(name) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Workerdtype_abbrs = { 208*da0073e9SAndroid Build Coastguard Worker torch.bfloat16: 'bf16', 209*da0073e9SAndroid Build Coastguard Worker torch.float64: 'f64', 210*da0073e9SAndroid Build Coastguard Worker torch.float32: 'f32', 211*da0073e9SAndroid Build Coastguard Worker torch.float16: 'f16', 212*da0073e9SAndroid Build Coastguard Worker torch.float8_e4m3fn: 'f8e4m3fn', 213*da0073e9SAndroid Build Coastguard Worker torch.float8_e5m2: 'f8e5m2', 214*da0073e9SAndroid Build Coastguard Worker torch.float8_e4m3fnuz: 'f8e4m3fnuz', 215*da0073e9SAndroid Build Coastguard Worker torch.float8_e5m2fnuz: 'f8e5m2fnuz', 216*da0073e9SAndroid Build Coastguard Worker torch.complex32: 'c32', 217*da0073e9SAndroid Build Coastguard Worker torch.complex64: 'c64', 218*da0073e9SAndroid Build Coastguard Worker torch.complex128: 'c128', 219*da0073e9SAndroid Build Coastguard Worker torch.int8: 'i8', 220*da0073e9SAndroid Build Coastguard Worker torch.int16: 'i16', 221*da0073e9SAndroid Build Coastguard Worker torch.int32: 'i32', 222*da0073e9SAndroid Build Coastguard Worker torch.int64: 'i64', 223*da0073e9SAndroid Build Coastguard Worker torch.bool: 'b8', 224*da0073e9SAndroid Build Coastguard Worker torch.uint8: 'u8', 225*da0073e9SAndroid Build Coastguard Worker 
torch.uint16: 'u16', 226*da0073e9SAndroid Build Coastguard Worker torch.uint32: 'u32', 227*da0073e9SAndroid Build Coastguard Worker torch.uint64: 'u64', 228*da0073e9SAndroid Build Coastguard Worker torch.bits16: 'b16', 229*da0073e9SAndroid Build Coastguard Worker} 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 232*da0073e9SAndroid Build Coastguard Worker@dataclass 233*da0073e9SAndroid Build Coastguard Workerclass PythonCode: 234*da0073e9SAndroid Build Coastguard Worker """ 235*da0073e9SAndroid Build Coastguard Worker Represents all the information necessary to exec or save a graph as Python code. 236*da0073e9SAndroid Build Coastguard Worker """ 237*da0073e9SAndroid Build Coastguard Worker # Python source code for the forward function definition. 238*da0073e9SAndroid Build Coastguard Worker src: str 239*da0073e9SAndroid Build Coastguard Worker # Values in global scope during execution of `src_def`. 240*da0073e9SAndroid Build Coastguard Worker globals: Dict[str, Any] 241*da0073e9SAndroid Build Coastguard Worker # Optional mapping from the forward function's line number to 242*da0073e9SAndroid Build Coastguard Worker # node index. 
243*da0073e9SAndroid Build Coastguard Worker _lineno_map: Optional[Dict[int, Optional[int]]] 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Workerdef _format_target(base: str, target: str) -> str: 247*da0073e9SAndroid Build Coastguard Worker elems = target.split('.') 248*da0073e9SAndroid Build Coastguard Worker r = base 249*da0073e9SAndroid Build Coastguard Worker for e in elems: 250*da0073e9SAndroid Build Coastguard Worker if not e.isidentifier(): 251*da0073e9SAndroid Build Coastguard Worker r = f'getattr({r}, "{e}")' 252*da0073e9SAndroid Build Coastguard Worker else: 253*da0073e9SAndroid Build Coastguard Worker r = f'{r}.{e}' 254*da0073e9SAndroid Build Coastguard Worker return r 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Workerclass _InsertPoint: 257*da0073e9SAndroid Build Coastguard Worker def __init__(self, graph, new_insert): 258*da0073e9SAndroid Build Coastguard Worker self.graph = graph 259*da0073e9SAndroid Build Coastguard Worker self.orig_insert, graph._insert = graph._insert, new_insert 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 262*da0073e9SAndroid Build Coastguard Worker pass 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def __exit__(self, type, value, tb): 265*da0073e9SAndroid Build Coastguard Worker self.graph._insert = self.orig_insert 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Workerclass _node_list: 268*da0073e9SAndroid Build Coastguard Worker def __init__(self, graph: 'Graph', direction: str = '_next'): 269*da0073e9SAndroid Build Coastguard Worker assert direction in ['_next', '_prev'] 270*da0073e9SAndroid Build Coastguard Worker self.graph = graph 271*da0073e9SAndroid Build Coastguard Worker self.direction = direction 272*da0073e9SAndroid Build Coastguard Worker 
273*da0073e9SAndroid Build Coastguard Worker def __len__(self): 274*da0073e9SAndroid Build Coastguard Worker return self.graph._len 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 277*da0073e9SAndroid Build Coastguard Worker assert self.direction == "_prev" or self.direction == "_next" 278*da0073e9SAndroid Build Coastguard Worker yield from _NodeIter(self.graph._root, self.direction == "_prev") 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker def __reversed__(self): 281*da0073e9SAndroid Build Coastguard Worker return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Workerclass _PyTreeInfo(NamedTuple): 284*da0073e9SAndroid Build Coastguard Worker """ 285*da0073e9SAndroid Build Coastguard Worker Contains extra info stored when we're using Pytrees 286*da0073e9SAndroid Build Coastguard Worker """ 287*da0073e9SAndroid Build Coastguard Worker orig_args: List[str] 288*da0073e9SAndroid Build Coastguard Worker in_spec: pytree.TreeSpec 289*da0073e9SAndroid Build Coastguard Worker out_spec: Optional[pytree.TreeSpec] 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 292*da0073e9SAndroid Build Coastguard Workerclass _ParsedStackTrace: 293*da0073e9SAndroid Build Coastguard Worker """ 294*da0073e9SAndroid Build Coastguard Worker Represents the top-most frame of a parsed stack trace 295*da0073e9SAndroid Build Coastguard Worker """ 296*da0073e9SAndroid Build Coastguard Worker file: str 297*da0073e9SAndroid Build Coastguard Worker lineno: str 298*da0073e9SAndroid Build Coastguard Worker name: str 299*da0073e9SAndroid Build Coastguard Worker code: str 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker def get_summary_str(self): 302*da0073e9SAndroid Build Coastguard Worker 
return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}' 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker# get File:lineno code from stack_trace 305*da0073e9SAndroid Build Coastguard Workerdef _parse_stack_trace(stack_trace: str): 306*da0073e9SAndroid Build Coastguard Worker if stack_trace is None: 307*da0073e9SAndroid Build Coastguard Worker return None 308*da0073e9SAndroid Build Coastguard Worker pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") 309*da0073e9SAndroid Build Coastguard Worker lines = stack_trace.strip().split('\n') 310*da0073e9SAndroid Build Coastguard Worker # stacktrace should have innermost frame last, so we 311*da0073e9SAndroid Build Coastguard Worker # iterate backwards to find the first line that starts 312*da0073e9SAndroid Build Coastguard Worker # with 'File ' 313*da0073e9SAndroid Build Coastguard Worker summary_str = "" 314*da0073e9SAndroid Build Coastguard Worker for idx in range(len(lines) - 2, -1, -1): 315*da0073e9SAndroid Build Coastguard Worker line = lines[idx].strip() 316*da0073e9SAndroid Build Coastguard Worker matches = pattern.match(line) 317*da0073e9SAndroid Build Coastguard Worker if matches: 318*da0073e9SAndroid Build Coastguard Worker file = matches.group(1) 319*da0073e9SAndroid Build Coastguard Worker lineno = matches.group(2) 320*da0073e9SAndroid Build Coastguard Worker name = matches.group(3) 321*da0073e9SAndroid Build Coastguard Worker # next line should be the code 322*da0073e9SAndroid Build Coastguard Worker code = lines[idx + 1].strip() 323*da0073e9SAndroid Build Coastguard Worker return _ParsedStackTrace(file, lineno, name, code) 324*da0073e9SAndroid Build Coastguard Worker return None 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=False) 327*da0073e9SAndroid Build Coastguard Workerclass CodeGen: 328*da0073e9SAndroid Build Coastguard Worker def __init__(self): 
329*da0073e9SAndroid Build Coastguard Worker self._body_transformer: Optional[TransformCodeFunc] = None 330*da0073e9SAndroid Build Coastguard Worker self._func_name: str = "forward" 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: 333*da0073e9SAndroid Build Coastguard Worker """ 334*da0073e9SAndroid Build Coastguard Worker Given the free variables and a return annotation, generates the beginning of the FX function. 335*da0073e9SAndroid Build Coastguard Worker By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` 336*da0073e9SAndroid Build Coastguard Worker """ 337*da0073e9SAndroid Build Coastguard Worker # If the original function didn't have self as its first argument, we 338*da0073e9SAndroid Build Coastguard Worker # would have added it. 339*da0073e9SAndroid Build Coastguard Worker if len(free_vars) == 0 or free_vars[0] != 'self': 340*da0073e9SAndroid Build Coastguard Worker free_vars.insert(0, 'self') 341*da0073e9SAndroid Build Coastguard Worker return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker def generate_output(self, output_args: Argument) -> str: 344*da0073e9SAndroid Build Coastguard Worker """ 345*da0073e9SAndroid Build Coastguard Worker Given the output arguments, generates the return statement of the FX function. 346*da0073e9SAndroid Build Coastguard Worker Note: The returned statement should not be indented. 
347*da0073e9SAndroid Build Coastguard Worker """ 348*da0073e9SAndroid Build Coastguard Worker return f'return {repr(output_args)}' 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker def process_inputs(self, *args: Any) -> Any: 351*da0073e9SAndroid Build Coastguard Worker """ 352*da0073e9SAndroid Build Coastguard Worker Transforms the inputs so that the graph can take them as arguments, as 353*da0073e9SAndroid Build Coastguard Worker non-default codegen may result in the inputs to the function being 354*da0073e9SAndroid Build Coastguard Worker different from the inputs to the graph. 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker If the graph was directly runnable, this invariant should hold true 357*da0073e9SAndroid Build Coastguard Worker `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` 358*da0073e9SAndroid Build Coastguard Worker """ 359*da0073e9SAndroid Build Coastguard Worker return args 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def process_outputs(self, outputs: Any) -> Any: 362*da0073e9SAndroid Build Coastguard Worker """ 363*da0073e9SAndroid Build Coastguard Worker Transforms the outputs of the graph to be identical to the codegen. 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker See ``process_inputs`` for more details. 366*da0073e9SAndroid Build Coastguard Worker """ 367*da0073e9SAndroid Build Coastguard Worker return outputs 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker def additional_globals(self) -> List[Tuple[str, Any]]: 370*da0073e9SAndroid Build Coastguard Worker """ 371*da0073e9SAndroid Build Coastguard Worker If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. 
372*da0073e9SAndroid Build Coastguard Worker For example, return ['List', typing.List] if you need ``List`` in the global context. 373*da0073e9SAndroid Build Coastguard Worker """ 374*da0073e9SAndroid Build Coastguard Worker return [] 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker def _gen_python_code( 377*da0073e9SAndroid Build Coastguard Worker self, nodes, root_module: str, namespace: _Namespace, *, 378*da0073e9SAndroid Build Coastguard Worker verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False 379*da0073e9SAndroid Build Coastguard Worker ) -> PythonCode: 380*da0073e9SAndroid Build Coastguard Worker free_vars: List[str] = [] 381*da0073e9SAndroid Build Coastguard Worker body: List[str] = [] 382*da0073e9SAndroid Build Coastguard Worker globals_: Dict[str, Any] = {} 383*da0073e9SAndroid Build Coastguard Worker wrapped_fns: Dict[str, None] = {} 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker # Wrap string in list to pass by reference 386*da0073e9SAndroid Build Coastguard Worker maybe_return_annotation : List[str] = [''] 387*da0073e9SAndroid Build Coastguard Worker include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1") 388*da0073e9SAndroid Build Coastguard Worker include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1") 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker def add_global(name_hint: str, obj: Any): 391*da0073e9SAndroid Build Coastguard Worker """Add an obj to be tracked as a global. 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker We call this for names that reference objects external to the 394*da0073e9SAndroid Build Coastguard Worker Graph, like functions or types. 
395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker Returns: the global name that should be used to reference 'obj' in generated source. 397*da0073e9SAndroid Build Coastguard Worker """ 398*da0073e9SAndroid Build Coastguard Worker if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device 399*da0073e9SAndroid Build Coastguard Worker # HACK: workaround for how torch custom ops are registered. We 400*da0073e9SAndroid Build Coastguard Worker # can't import them like normal modules so they must retain their 401*da0073e9SAndroid Build Coastguard Worker # fully qualified name. 402*da0073e9SAndroid Build Coastguard Worker return _get_qualified_name(obj) 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker # normalize the name hint to get a proper identifier 405*da0073e9SAndroid Build Coastguard Worker global_name = namespace.create_name(name_hint, obj) 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker if global_name in globals_: 408*da0073e9SAndroid Build Coastguard Worker assert globals_[global_name] is obj 409*da0073e9SAndroid Build Coastguard Worker return global_name 410*da0073e9SAndroid Build Coastguard Worker globals_[global_name] = obj 411*da0073e9SAndroid Build Coastguard Worker return global_name 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker # Pre-fill the globals table with registered builtins. 
414*da0073e9SAndroid Build Coastguard Worker for name, (_, obj) in _custom_builtins.items(): 415*da0073e9SAndroid Build Coastguard Worker add_global(name, obj) 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker def type_repr(o : Any): 418*da0073e9SAndroid Build Coastguard Worker if o == (): 419*da0073e9SAndroid Build Coastguard Worker # Empty tuple is used for empty tuple type annotation Tuple[()] 420*da0073e9SAndroid Build Coastguard Worker return '()' 421*da0073e9SAndroid Build Coastguard Worker 422*da0073e9SAndroid Build Coastguard Worker typename = _type_repr(o) 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker if hasattr(o, '__origin__'): 425*da0073e9SAndroid Build Coastguard Worker # This is a generic type, e.g. typing.List[torch.Tensor] 426*da0073e9SAndroid Build Coastguard Worker origin_type = _origin_type_map.get(o.__origin__, o.__origin__) 427*da0073e9SAndroid Build Coastguard Worker origin_typename = add_global(_type_repr(origin_type), origin_type) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker if hasattr(o, '__args__'): 430*da0073e9SAndroid Build Coastguard Worker # Assign global names for each of the inner type variables. 
431*da0073e9SAndroid Build Coastguard Worker args = [type_repr(arg) for arg in o.__args__] 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker if len(args) == 0: 434*da0073e9SAndroid Build Coastguard Worker # Bare type, such as `typing.Tuple` with no subscript 435*da0073e9SAndroid Build Coastguard Worker # This code-path used in Python < 3.9 436*da0073e9SAndroid Build Coastguard Worker return origin_typename 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Worker return f'{origin_typename}[{",".join(args)}]' 439*da0073e9SAndroid Build Coastguard Worker else: 440*da0073e9SAndroid Build Coastguard Worker # Bare type, such as `typing.Tuple` with no subscript 441*da0073e9SAndroid Build Coastguard Worker # This code-path used in Python 3.9+ 442*da0073e9SAndroid Build Coastguard Worker return origin_typename 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker # Common case: this is a regular module name like 'foo.bar.baz' 445*da0073e9SAndroid Build Coastguard Worker return add_global(typename, o) 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker codes = { 448*da0073e9SAndroid Build Coastguard Worker "yellow": "\033[33m", 449*da0073e9SAndroid Build Coastguard Worker "cyan": "\033[36m", 450*da0073e9SAndroid Build Coastguard Worker "green": "\033[32m", 451*da0073e9SAndroid Build Coastguard Worker "blue": "\033[34m", 452*da0073e9SAndroid Build Coastguard Worker "red": "\033[31m", 453*da0073e9SAndroid Build Coastguard Worker "dim": "\033[2m", 454*da0073e9SAndroid Build Coastguard Worker "dim_blue": "\033[2m\033[34m", 455*da0073e9SAndroid Build Coastguard Worker "dim_green": "\033[2m\033[32m", 456*da0073e9SAndroid Build Coastguard Worker "reset": "\033[0m", 457*da0073e9SAndroid Build Coastguard Worker } 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker def make_wrapper_func(name): 
460*da0073e9SAndroid Build Coastguard Worker def f(s): 461*da0073e9SAndroid Build Coastguard Worker if colored: 462*da0073e9SAndroid Build Coastguard Worker return f"{codes[name]}{s}{codes['reset']}" 463*da0073e9SAndroid Build Coastguard Worker return s 464*da0073e9SAndroid Build Coastguard Worker return f 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker yellow = make_wrapper_func("yellow") 467*da0073e9SAndroid Build Coastguard Worker cyan = make_wrapper_func("cyan") 468*da0073e9SAndroid Build Coastguard Worker red = make_wrapper_func("red") 469*da0073e9SAndroid Build Coastguard Worker green = make_wrapper_func("green") 470*da0073e9SAndroid Build Coastguard Worker dim_green = make_wrapper_func("dim_green") 471*da0073e9SAndroid Build Coastguard Worker dim = make_wrapper_func("dim") 472*da0073e9SAndroid Build Coastguard Worker dim_blue = make_wrapper_func("dim_blue") 473*da0073e9SAndroid Build Coastguard Worker blue = make_wrapper_func("blue") 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker def _get_repr(arg: Any) -> str: 476*da0073e9SAndroid Build Coastguard Worker # Handle NamedTuples (if it has `_fields`) via add_global. 
477*da0073e9SAndroid Build Coastguard Worker if isinstance(arg, tuple) and hasattr(arg, '_fields'): 478*da0073e9SAndroid Build Coastguard Worker qualified_name = _get_qualified_name(type(arg)) 479*da0073e9SAndroid Build Coastguard Worker global_name = add_global(qualified_name, type(arg)) 480*da0073e9SAndroid Build Coastguard Worker return f"{global_name}{repr(tuple(arg))}" 481*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): 482*da0073e9SAndroid Build Coastguard Worker qualified_name = _get_qualified_name(arg) 483*da0073e9SAndroid Build Coastguard Worker global_name = add_global(qualified_name, arg) 484*da0073e9SAndroid Build Coastguard Worker return f"{global_name}" 485*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, enum.Enum): 486*da0073e9SAndroid Build Coastguard Worker cls = arg.__class__ 487*da0073e9SAndroid Build Coastguard Worker clsname = add_global(cls.__name__, cls) 488*da0073e9SAndroid Build Coastguard Worker return f"{clsname}.{arg.name}" 489*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, Node): 490*da0073e9SAndroid Build Coastguard Worker return repr(arg) 491*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, torch.Tensor): 492*da0073e9SAndroid Build Coastguard Worker size = list(arg.size()) 493*da0073e9SAndroid Build Coastguard Worker dtype = str(arg.dtype).split(".")[-1] 494*da0073e9SAndroid Build Coastguard Worker return f"torch.Tensor(size={size}, dtype={dtype})" 495*da0073e9SAndroid Build Coastguard Worker else: 496*da0073e9SAndroid Build Coastguard Worker return blue(repr(arg)) 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker 499*da0073e9SAndroid Build Coastguard Worker def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: 500*da0073e9SAndroid Build Coastguard Worker args_s = ', '.join(_get_repr(a) for a in args) 501*da0073e9SAndroid Build Coastguard Worker kwargs_s = 
', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) 502*da0073e9SAndroid Build Coastguard Worker if args_s and kwargs_s: 503*da0073e9SAndroid Build Coastguard Worker return f'{args_s}, {kwargs_s}' 504*da0073e9SAndroid Build Coastguard Worker return args_s or kwargs_s 505*da0073e9SAndroid Build Coastguard Worker 506*da0073e9SAndroid Build Coastguard Worker # Run through reverse nodes and record the first instance of a use 507*da0073e9SAndroid Build Coastguard Worker # of a given node. This represents the *last* use of the node in the 508*da0073e9SAndroid Build Coastguard Worker # execution order of the program, which we will use to free unused 509*da0073e9SAndroid Build Coastguard Worker # values 510*da0073e9SAndroid Build Coastguard Worker node_to_last_use : Dict[Node, Node] = {} 511*da0073e9SAndroid Build Coastguard Worker user_to_last_uses : Dict[Node, List[Node]] = {} 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker def register_last_uses(n : Node, user : Node): 514*da0073e9SAndroid Build Coastguard Worker if n not in node_to_last_use: 515*da0073e9SAndroid Build Coastguard Worker node_to_last_use[n] = user 516*da0073e9SAndroid Build Coastguard Worker user_to_last_uses.setdefault(user, []).append(n) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker for node in reversed(nodes): 519*da0073e9SAndroid Build Coastguard Worker map_arg(node.args, lambda n: register_last_uses(n, node)) 520*da0073e9SAndroid Build Coastguard Worker map_arg(node.kwargs, lambda n: register_last_uses(n, node)) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker def delete_unused_values(user : Node): 523*da0073e9SAndroid Build Coastguard Worker """ 524*da0073e9SAndroid Build Coastguard Worker Delete values after their last use. 
This ensures that values that are 525*da0073e9SAndroid Build Coastguard Worker not used in the remainder of the code are freed and the memory usage 526*da0073e9SAndroid Build Coastguard Worker of the code is optimal. 527*da0073e9SAndroid Build Coastguard Worker """ 528*da0073e9SAndroid Build Coastguard Worker if user.op == 'placeholder': 529*da0073e9SAndroid Build Coastguard Worker return 530*da0073e9SAndroid Build Coastguard Worker if user.op == 'output': 531*da0073e9SAndroid Build Coastguard Worker body.append('\n') 532*da0073e9SAndroid Build Coastguard Worker return 533*da0073e9SAndroid Build Coastguard Worker nodes_to_delete = user_to_last_uses.get(user, []) 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker if len(user.users.keys()) == 0: 536*da0073e9SAndroid Build Coastguard Worker # This node is not used by any others. however it's also not 537*da0073e9SAndroid Build Coastguard Worker # removed by DCE since side-effect. We want to free it's outputs 538*da0073e9SAndroid Build Coastguard Worker # right after its execution done to save memory. 
539*da0073e9SAndroid Build Coastguard Worker nodes_to_delete.append(user) 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker if len(nodes_to_delete): 542*da0073e9SAndroid Build Coastguard Worker to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) 543*da0073e9SAndroid Build Coastguard Worker body.append(f'; {dim(to_delete_str)}\n') 544*da0073e9SAndroid Build Coastguard Worker else: 545*da0073e9SAndroid Build Coastguard Worker body.append('\n') 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker prev_stacktrace = None 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker def append_stacktrace_summary(node : Node): 550*da0073e9SAndroid Build Coastguard Worker """ 551*da0073e9SAndroid Build Coastguard Worker Append a summary of the stacktrace to the generated code. This is 552*da0073e9SAndroid Build Coastguard Worker useful for debugging. 553*da0073e9SAndroid Build Coastguard Worker """ 554*da0073e9SAndroid Build Coastguard Worker nonlocal prev_stacktrace 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker if node.op not in {'placeholder', 'output'}: 557*da0073e9SAndroid Build Coastguard Worker if node.stack_trace: 558*da0073e9SAndroid Build Coastguard Worker if node.stack_trace != prev_stacktrace: 559*da0073e9SAndroid Build Coastguard Worker prev_stacktrace = node.stack_trace 560*da0073e9SAndroid Build Coastguard Worker summary_str = "" 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker if parsed_stack_trace := _parse_stack_trace(node.stack_trace): 563*da0073e9SAndroid Build Coastguard Worker summary_str = parsed_stack_trace.get_summary_str() 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker body.append(f'\n {dim("# " + summary_str)}\n') 566*da0073e9SAndroid Build Coastguard Worker elif prev_stacktrace != "": 
567*da0073e9SAndroid Build Coastguard Worker prev_stacktrace = "" 568*da0073e9SAndroid Build Coastguard Worker no_stacktrace_msg = "# No stacktrace found for following nodes" 569*da0073e9SAndroid Build Coastguard Worker body.append(f'\n{dim(no_stacktrace_msg)}\n') 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker def stringify_shape(shape : Iterable) -> str: 572*da0073e9SAndroid Build Coastguard Worker return f"[{', '.join(str(x) for x in shape)}]" 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker def emit_node(node : Node): 575*da0073e9SAndroid Build Coastguard Worker maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker if verbose: 578*da0073e9SAndroid Build Coastguard Worker # override annotation with more detailed information 579*da0073e9SAndroid Build Coastguard Worker from torch.fx.experimental.proxy_tensor import py_sym_types 580*da0073e9SAndroid Build Coastguard Worker from torch.fx.passes.shape_prop import TensorMetadata 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None))) 583*da0073e9SAndroid Build Coastguard Worker # use string as annotation, to make it valid python code 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker if isinstance(meta_val, torch.Tensor): 586*da0073e9SAndroid Build Coastguard Worker stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else "" 587*da0073e9SAndroid Build Coastguard Worker device_annotation = f"{meta_val.device}" if include_device else "" 588*da0073e9SAndroid Build Coastguard Worker maybe_type_annotation = \ 589*da0073e9SAndroid Build Coastguard Worker f': 
"{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \ 590*da0073e9SAndroid Build Coastguard Worker f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' 591*da0073e9SAndroid Build Coastguard Worker elif isinstance(meta_val, py_sym_types): 592*da0073e9SAndroid Build Coastguard Worker maybe_type_annotation = f': "Sym({meta_val})"' 593*da0073e9SAndroid Build Coastguard Worker elif isinstance(meta_val, TensorMetadata): 594*da0073e9SAndroid Build Coastguard Worker maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker if node.op == 'placeholder': 597*da0073e9SAndroid Build Coastguard Worker assert isinstance(node.target, str) 598*da0073e9SAndroid Build Coastguard Worker maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}' 599*da0073e9SAndroid Build Coastguard Worker free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') 600*da0073e9SAndroid Build Coastguard Worker raw_name = node.target.replace('*', '') 601*da0073e9SAndroid Build Coastguard Worker if raw_name != repr(node): 602*da0073e9SAndroid Build Coastguard Worker body.append(f'{repr(node)} = {raw_name}\n') 603*da0073e9SAndroid Build Coastguard Worker return 604*da0073e9SAndroid Build Coastguard Worker elif node.op == 'call_method': 605*da0073e9SAndroid Build Coastguard Worker assert isinstance(node.target, str) 606*da0073e9SAndroid Build Coastguard Worker body.append( 607*da0073e9SAndroid Build Coastguard Worker f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}' 608*da0073e9SAndroid Build Coastguard Worker f'({_format_args(node.args[1:], node.kwargs)})') 609*da0073e9SAndroid Build Coastguard Worker return 610*da0073e9SAndroid Build Coastguard Worker elif node.op == 'call_function': 611*da0073e9SAndroid Build Coastguard Worker assert callable(node.target) 
612*da0073e9SAndroid Build Coastguard Worker # pretty print operators 613*da0073e9SAndroid Build Coastguard Worker if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods: 614*da0073e9SAndroid Build Coastguard Worker assert isinstance(node.args, tuple) 615*da0073e9SAndroid Build Coastguard Worker body.append(f'{repr(node)}{maybe_type_annotation} = ' 616*da0073e9SAndroid Build Coastguard Worker f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}') 617*da0073e9SAndroid Build Coastguard Worker return 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker # pretty print inplace operators; required for jit.script to work properly 620*da0073e9SAndroid Build Coastguard Worker # not currently supported in normal FX graphs, but generated by torchdynamo 621*da0073e9SAndroid Build Coastguard Worker if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods: 622*da0073e9SAndroid Build Coastguard Worker body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; ' 623*da0073e9SAndroid Build Coastguard Worker f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}') 624*da0073e9SAndroid Build Coastguard Worker return 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker qualified_name = _get_qualified_name(node.target) 627*da0073e9SAndroid Build Coastguard Worker global_name = add_global(qualified_name, node.target) 628*da0073e9SAndroid Build Coastguard Worker # special case for getattr: node.args could be 2-argument or 3-argument 629*da0073e9SAndroid Build Coastguard Worker # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value 630*da0073e9SAndroid Build Coastguard Worker if global_name == 'getattr' and \ 631*da0073e9SAndroid Build Coastguard Worker isinstance(node.args, tuple) and \ 
632*da0073e9SAndroid Build Coastguard Worker isinstance(node.args[1], str) and \ 633*da0073e9SAndroid Build Coastguard Worker node.args[1].isidentifier() and \ 634*da0073e9SAndroid Build Coastguard Worker len(node.args) == 2: 635*da0073e9SAndroid Build Coastguard Worker body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}') 636*da0073e9SAndroid Build Coastguard Worker return 637*da0073e9SAndroid Build Coastguard Worker body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') 638*da0073e9SAndroid Build Coastguard Worker if node.meta.get('is_wrapped', False): 639*da0073e9SAndroid Build Coastguard Worker wrapped_fns.setdefault(global_name) 640*da0073e9SAndroid Build Coastguard Worker return 641*da0073e9SAndroid Build Coastguard Worker elif node.op == 'call_module': 642*da0073e9SAndroid Build Coastguard Worker assert isinstance(node.target, str) 643*da0073e9SAndroid Build Coastguard Worker body.append(f'{repr(node)}{maybe_type_annotation} = ' 644*da0073e9SAndroid Build Coastguard Worker f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') 645*da0073e9SAndroid Build Coastguard Worker return 646*da0073e9SAndroid Build Coastguard Worker elif node.op == 'get_attr': 647*da0073e9SAndroid Build Coastguard Worker assert isinstance(node.target, str) 648*da0073e9SAndroid Build Coastguard Worker body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') 649*da0073e9SAndroid Build Coastguard Worker return 650*da0073e9SAndroid Build Coastguard Worker elif node.op == 'output': 651*da0073e9SAndroid Build Coastguard Worker if node.type is not None: 652*da0073e9SAndroid Build Coastguard Worker maybe_return_annotation[0] = f" -> {type_repr(node.type)}" 653*da0073e9SAndroid Build Coastguard Worker body.append(self.generate_output(node.args[0])) 654*da0073e9SAndroid Build Coastguard Worker return 
655*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(f'node: {node.op} {node.target}') 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker for i, node in enumerate(nodes): 658*da0073e9SAndroid Build Coastguard Worker # NOTE: emit_node does not emit a string with newline. It depends 659*da0073e9SAndroid Build Coastguard Worker # on delete_unused_values to append one 660*da0073e9SAndroid Build Coastguard Worker if verbose: 661*da0073e9SAndroid Build Coastguard Worker append_stacktrace_summary(node) 662*da0073e9SAndroid Build Coastguard Worker # emit a counter comment to keep track of 663*da0073e9SAndroid Build Coastguard Worker # node index, which will be deleted later 664*da0073e9SAndroid Build Coastguard Worker # after going through _body_transformer 665*da0073e9SAndroid Build Coastguard Worker body.append(f"# COUNTER: {i}\n") 666*da0073e9SAndroid Build Coastguard Worker emit_node(node) 667*da0073e9SAndroid Build Coastguard Worker delete_unused_values(node) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker if len(body) == 0: 670*da0073e9SAndroid Build Coastguard Worker # If the Graph has no non-placeholder nodes, no lines for the body 671*da0073e9SAndroid Build Coastguard Worker # have been emitted. 
To continue to have valid Python code, emit a 672*da0073e9SAndroid Build Coastguard Worker # single pass statement 673*da0073e9SAndroid Build Coastguard Worker body.append('pass\n') 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker if len(wrapped_fns) > 0: 678*da0073e9SAndroid Build Coastguard Worker wrap_name = add_global('wrap', torch.fx.wrap) 679*da0073e9SAndroid Build Coastguard Worker wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) 680*da0073e9SAndroid Build Coastguard Worker else: 681*da0073e9SAndroid Build Coastguard Worker wrap_stmts = '' 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker if self._body_transformer: 684*da0073e9SAndroid Build Coastguard Worker body = self._body_transformer(body) 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Worker for name, value in self.additional_globals(): 687*da0073e9SAndroid Build Coastguard Worker add_global(name, value) 688*da0073e9SAndroid Build Coastguard Worker 689*da0073e9SAndroid Build Coastguard Worker prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker # remove counter and generate lineno to node index mapping 692*da0073e9SAndroid Build Coastguard Worker lineno_map: Dict[int, Optional[int]] = {} 693*da0073e9SAndroid Build Coastguard Worker prologue_len = prologue.count('\n') + 1 694*da0073e9SAndroid Build Coastguard Worker new_lines: List[str] = [] 695*da0073e9SAndroid Build Coastguard Worker cur_idx = None 696*da0073e9SAndroid Build Coastguard Worker for line in ''.join(body).split('\n'): 697*da0073e9SAndroid Build Coastguard Worker counter = re.search(r"# COUNTER: (\d+)", line) 698*da0073e9SAndroid Build Coastguard Worker if counter and counter.group(1) is not None: 
699*da0073e9SAndroid Build Coastguard Worker cur_idx = int(counter.group(1)) 700*da0073e9SAndroid Build Coastguard Worker else: 701*da0073e9SAndroid Build Coastguard Worker lineno_map[len(new_lines) + prologue_len] = cur_idx 702*da0073e9SAndroid Build Coastguard Worker new_lines.append(line) 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker code = "\n".join(new_lines).lstrip('\n') 705*da0073e9SAndroid Build Coastguard Worker code = '\n'.join(' ' + line for line in code.split('\n')) 706*da0073e9SAndroid Build Coastguard Worker 707*da0073e9SAndroid Build Coastguard Worker fn_code = f""" 708*da0073e9SAndroid Build Coastguard Worker{wrap_stmts} 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker{prologue} 711*da0073e9SAndroid Build Coastguard Worker{code}""" 712*da0073e9SAndroid Build Coastguard Worker return PythonCode(fn_code, globals_, _lineno_map=lineno_map) 713*da0073e9SAndroid Build Coastguard Worker 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker# Ideally, we'd like to refactor all of the pytree logic into this codegen 716*da0073e9SAndroid Build Coastguard Worker# class. Unfortunately, there are 3 areas we currently need extra logic in FX. 717*da0073e9SAndroid Build Coastguard Worker# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. 718*da0073e9SAndroid Build Coastguard Worker# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. 719*da0073e9SAndroid Build Coastguard Worker# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. 720*da0073e9SAndroid Build Coastguard Worker# 3. We currently can't register the pytree imports with `add_global` - not sure why. 
class _PyTreeCodeGen(CodeGen):
    """``CodeGen`` variant whose generated code speaks pytrees at its boundary.

    The emitted ``forward`` takes the argument names recorded in
    ``pytree_info.orig_args``, flattens them into the graph's placeholders
    with ``fx_pytree.tree_flatten_spec(..., self._in_spec)`` and, when an
    output spec exists, returns ``pytree.tree_unflatten(..., self._out_spec)``.
    When ``pytree_info`` is ``None``, ``gen_fn_def`` and ``generate_output``
    defer to ``CodeGen`` and ``process_outputs`` returns its input unchanged.
    """

    def __init__(self, pytree_info: _PyTreeInfo):
        super().__init__()
        # Read below through .orig_args, .in_spec and .out_spec; may be None.
        self.pytree_info: _PyTreeInfo = pytree_info

    def process_inputs(self, *inputs: Any) -> Any:
        """Return the flat list of pytree leaves found in ``inputs``."""
        return pytree.arg_tree_leaves(*inputs)

    def process_outputs(self, out: Any) -> Any:
        """Rebuild the nested output structure described by ``out_spec``.

        ``out`` is returned untouched when there is no ``pytree_info`` or no
        ``out_spec``. A single value that is not a list/tuple is wrapped in a
        one-element list before unflattening.
        """
        info = self.pytree_info
        if info is None or info.out_spec is None:
            return out
        leaves = out if isinstance(out, (list, tuple)) else [out]
        return pytree.tree_unflatten(leaves, info.out_spec)

    def gen_fn_def(self, free_vars, maybe_return_annotation):
        """Emit the ``forward`` header plus the line that unpacks its placeholders.

        Args:
            free_vars: one declaration string per graph placeholder, as built
                by the caller (name, optional ``: annotation``, optional
                ``= default``). NOTE: mutated in place -- ``'self'`` is
                prepended when the original signature starts with ``self``.
            maybe_return_annotation: return-annotation suffix (possibly empty)
                forwarded to ``CodeGen.gen_fn_def``.

        Returns:
            The source text of the function definition.
        """
        # For a user function/model such as
        #     myargs = [myargs0, myargs1]
        #     mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
        #     def forward(self, mypos, *myargs, mykey=None, **mykwargs):
        #
        # the generated code turns every keyword into a positional argument:
        #     forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1)
        #
        # and inside ``forward`` re-packs them for ``tree_flatten_spec``, which
        # still receives positional and keyword arguments separately:
        #     tree_flatten_spec(([mypos, myargs0, myargs1],
        #                        {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
        #                       self._in_spec)
        #
        # Without keyword arguments the dict is left out:
        #     tree_flatten_spec([mypos, myargs0, myargs1], self._in_spec)
        info = self.pytree_info
        if info is None:
            return super().gen_fn_def(free_vars, maybe_return_annotation)

        orig_args = info.orig_args
        if orig_args and orig_args[0] == 'self':
            free_vars.insert(0, 'self')
        # Hand the base class a copy so in-place edits there leave orig_args intact.
        fn_definition = super().gen_fn_def(orig_args[:], maybe_return_annotation)

        if not free_vars:
            # No placeholders in the graph: nothing to unpack.
            return fn_definition

        in_spec = info.in_spec
        # With keyword arguments present, in_spec is the pair (args tuple, kwargs dict).
        splits_args_and_kwargs = (
            in_spec.type == tuple
            and in_spec.num_children == 2
            and in_spec.children_specs[0].type == tuple
            and in_spec.children_specs[1].type == dict
        )
        if splits_args_and_kwargs:
            n_positional = in_spec.children_specs[0].num_children
            positional_names = orig_args[:n_positional]
            kwarg_pairs = zip(in_spec.children_specs[1].context, orig_args[n_positional:])
            kwargs_src = '{' + ', '.join(f"'{key}':{name}" for key, name in kwarg_pairs) + '}'
            fn_signature = f"([{', '.join(positional_names)}], {kwargs_src}), self._in_spec"
        else:
            fn_signature = f"[{', '.join(orig_args)}], self._in_spec"

        # `a: T1, b: T2 = f()` is not valid Python, so the annotations are
        # emitted on their own line as separate statements (`a: T1; b: T2; `,
        # note the semicolons), followed by a bare unpacking line (`a, b, = f()`).
        names = [decl.split(":")[0] for decl in free_vars]
        annotation_stmts = [decl + "; " for decl in free_vars if ":" in decl]
        # NOTE(review): a placeholder with a default but no annotation (e.g.
        # "x = 3") keeps its "= 3" suffix in ``names``, which would yield an
        # invalid unpacking target -- confirm such placeholders never reach here.
        if annotation_stmts:
            fn_definition += "\n " + "".join(annotation_stmts) + "\n"
        fn_definition += f"\n {', '.join(names)}, = fx_pytree.tree_flatten_spec({fn_signature})"
        return fn_definition

    def generate_output(self, output_args):
        """Return the ``return`` statement source for ``output_args``.

        When an ``out_spec`` is available, the flat outputs are re-nested with
        ``pytree.tree_unflatten`` against ``self._out_spec``. Otherwise the
        base ``CodeGen`` statement is used.
        """
        info = self.pytree_info
        if not (info and info.out_spec):
            return super().generate_output(output_args)
        return f'return pytree.tree_unflatten({output_args!r}, self._out_spec)'


class _FindNodesLookupTable:
    """
    Side table for the graph for the purpose of doing fast queries
    """
    def __init__(self):
        self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict)
804*da0073e9SAndroid Build Coastguard Worker def _key(self, node) -> Tuple[str, Optional[Target]]: 805*da0073e9SAndroid Build Coastguard Worker return (node.op, node.target if node.op == "call_function" else None) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker def __contains__(self, node) -> bool: 808*da0073e9SAndroid Build Coastguard Worker return node in self.table[self._key(node)] 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker def insert(self, node: Node) -> None: 811*da0073e9SAndroid Build Coastguard Worker self.table[self._key(node)][node] = None 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker def remove(self, node: Node) -> None: 814*da0073e9SAndroid Build Coastguard Worker self.table[self._key(node)].pop(node) 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker def find_nodes(self, *, op: str, target: Optional['Target'] = None): 817*da0073e9SAndroid Build Coastguard Worker if op == "call_function": 818*da0073e9SAndroid Build Coastguard Worker assert target is not None 819*da0073e9SAndroid Build Coastguard Worker return dict(self.table[(op, target)]).keys() 820*da0073e9SAndroid Build Coastguard Worker 821*da0073e9SAndroid Build Coastguard Worker if target is None: 822*da0073e9SAndroid Build Coastguard Worker return dict(self.table[(op, None)]).keys() 823*da0073e9SAndroid Build Coastguard Worker 824*da0073e9SAndroid Build Coastguard Worker # op is call_method, get_attr, call_module 825*da0073e9SAndroid Build Coastguard Worker return [node for node in self.table[(op, None)].keys() if node.target == target] 826*da0073e9SAndroid Build Coastguard Worker 827*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 828*da0073e9SAndroid Build Coastguard Workerclass Graph: 829*da0073e9SAndroid Build Coastguard Worker """ 830*da0073e9SAndroid Build Coastguard Worker ``Graph`` is 
the main data structure used in the FX Intermediate Representation. 831*da0073e9SAndroid Build Coastguard Worker It consists of a series of ``Node`` s, each representing callsites (or other 832*da0073e9SAndroid Build Coastguard Worker syntactic constructs). The list of ``Node`` s, taken together, constitute a 833*da0073e9SAndroid Build Coastguard Worker valid Python function. 834*da0073e9SAndroid Build Coastguard Worker 835*da0073e9SAndroid Build Coastguard Worker For example, the following code 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker .. code-block:: python 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker import torch 840*da0073e9SAndroid Build Coastguard Worker import torch.fx 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 843*da0073e9SAndroid Build Coastguard Worker def __init__(self): 844*da0073e9SAndroid Build Coastguard Worker super().__init__() 845*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 846*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 847*da0073e9SAndroid Build Coastguard Worker 848*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 849*da0073e9SAndroid Build Coastguard Worker return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Worker m = MyModule() 852*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(m) 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker Will produce the following Graph:: 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker print(gm.graph) 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker .. 
code-block:: text 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker graph(x): 861*da0073e9SAndroid Build Coastguard Worker %linear_weight : [num_users=1] = self.linear.weight 862*da0073e9SAndroid Build Coastguard Worker %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) 863*da0073e9SAndroid Build Coastguard Worker %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) 864*da0073e9SAndroid Build Coastguard Worker %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) 865*da0073e9SAndroid Build Coastguard Worker %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) 866*da0073e9SAndroid Build Coastguard Worker %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) 867*da0073e9SAndroid Build Coastguard Worker return topk_1 868*da0073e9SAndroid Build Coastguard Worker 869*da0073e9SAndroid Build Coastguard Worker For the semantics of operations represented in the ``Graph``, please see :class:`Node`. 870*da0073e9SAndroid Build Coastguard Worker """ 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 873*da0073e9SAndroid Build Coastguard Worker def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, 874*da0073e9SAndroid Build Coastguard Worker tracer_extras: Optional[Dict[str, Any]] = None): 875*da0073e9SAndroid Build Coastguard Worker """ 876*da0073e9SAndroid Build Coastguard Worker Construct an empty Graph. 
877*da0073e9SAndroid Build Coastguard Worker """ 878*da0073e9SAndroid Build Coastguard Worker self._root : Node = Node(self, '', 'root', '', (), {}) 879*da0073e9SAndroid Build Coastguard Worker self._used_names : Dict[str, int] = {} # base name -> number 880*da0073e9SAndroid Build Coastguard Worker self._insert = self._root.prepend 881*da0073e9SAndroid Build Coastguard Worker self._len = 0 882*da0073e9SAndroid Build Coastguard Worker self._graph_namespace = _Namespace() 883*da0073e9SAndroid Build Coastguard Worker self._owning_module = owning_module 884*da0073e9SAndroid Build Coastguard Worker self._tracer_cls = tracer_cls 885*da0073e9SAndroid Build Coastguard Worker self._tracer_extras = tracer_extras 886*da0073e9SAndroid Build Coastguard Worker self._codegen = CodeGen() 887*da0073e9SAndroid Build Coastguard Worker self._co_fields : Dict[str, Any] = {} 888*da0073e9SAndroid Build Coastguard Worker self._find_nodes_lookup_table = _FindNodesLookupTable() 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker @property 891*da0073e9SAndroid Build Coastguard Worker def owning_module(self): 892*da0073e9SAndroid Build Coastguard Worker return self._owning_module 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Worker @owning_module.setter 895*da0073e9SAndroid Build Coastguard Worker def owning_module(self, mod: Optional["GraphModule"]): 896*da0073e9SAndroid Build Coastguard Worker self._owning_module = mod 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker @property 899*da0073e9SAndroid Build Coastguard Worker def nodes(self) -> _node_list: 900*da0073e9SAndroid Build Coastguard Worker """ 901*da0073e9SAndroid Build Coastguard Worker Get the list of Nodes that constitute this Graph. 902*da0073e9SAndroid Build Coastguard Worker 903*da0073e9SAndroid Build Coastguard Worker Note that this ``Node`` list representation is a doubly-linked list. 
Mutations 904*da0073e9SAndroid Build Coastguard Worker during iteration (e.g. delete a Node, add a Node) are safe. 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker Returns: 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Worker A doubly-linked list of Nodes. Note that ``reversed`` can be called on 909*da0073e9SAndroid Build Coastguard Worker this list to switch iteration order. 910*da0073e9SAndroid Build Coastguard Worker """ 911*da0073e9SAndroid Build Coastguard Worker return _node_list(self) 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=False) 914*da0073e9SAndroid Build Coastguard Worker def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True): 915*da0073e9SAndroid Build Coastguard Worker """ 916*da0073e9SAndroid Build Coastguard Worker Allows for fast query of nodes 917*da0073e9SAndroid Build Coastguard Worker 918*da0073e9SAndroid Build Coastguard Worker Args: 919*da0073e9SAndroid Build Coastguard Worker 920*da0073e9SAndroid Build Coastguard Worker op (str): the name of the operation 921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker target (Optional[Target]): the target of the node. For call_function, 923*da0073e9SAndroid Build Coastguard Worker the target is required. For other ops, the target is optional. 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker sort (bool): whether to return nodes in the order they appear on 926*da0073e9SAndroid Build Coastguard Worker on the graph. 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker Returns: 929*da0073e9SAndroid Build Coastguard Worker 930*da0073e9SAndroid Build Coastguard Worker Iteratable of nodes with the requested op and target. 
931*da0073e9SAndroid Build Coastguard Worker """ 932*da0073e9SAndroid Build Coastguard Worker node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) 933*da0073e9SAndroid Build Coastguard Worker if sort: 934*da0073e9SAndroid Build Coastguard Worker return sorted(node_list) 935*da0073e9SAndroid Build Coastguard Worker return node_list 936*da0073e9SAndroid Build Coastguard Worker 937*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 938*da0073e9SAndroid Build Coastguard Worker def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': 939*da0073e9SAndroid Build Coastguard Worker """ 940*da0073e9SAndroid Build Coastguard Worker Copy all nodes from a given graph into ``self``. 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker Args: 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker g (Graph): The source graph from which to copy Nodes. 945*da0073e9SAndroid Build Coastguard Worker 946*da0073e9SAndroid Build Coastguard Worker val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping 947*da0073e9SAndroid Build Coastguard Worker from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed 948*da0073e9SAndroid Build Coastguard Worker in with values in it already to override copying of certain values. 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker Returns: 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker The value in ``self`` that is now equivalent to the output value in ``g``, 953*da0073e9SAndroid Build Coastguard Worker if ``g`` had an ``output`` node. ``None`` otherwise. 
954*da0073e9SAndroid Build Coastguard Worker """ 955*da0073e9SAndroid Build Coastguard Worker for node in g.nodes: 956*da0073e9SAndroid Build Coastguard Worker if node in val_map: 957*da0073e9SAndroid Build Coastguard Worker continue 958*da0073e9SAndroid Build Coastguard Worker if node.op == 'output': 959*da0073e9SAndroid Build Coastguard Worker rv = map_arg(node.args[0], lambda n: val_map[n]) 960*da0073e9SAndroid Build Coastguard Worker return rv if not return_output_node else (rv, node) 961*da0073e9SAndroid Build Coastguard Worker val_map[node] = self.node_copy(node, lambda n : val_map[n]) 962*da0073e9SAndroid Build Coastguard Worker return None 963*da0073e9SAndroid Build Coastguard Worker 964*da0073e9SAndroid Build Coastguard Worker def __deepcopy__(self, memo=None) -> 'Graph': 965*da0073e9SAndroid Build Coastguard Worker """ 966*da0073e9SAndroid Build Coastguard Worker Explicitly implement __deepcopy__ to prevent excessive recursion depth 967*da0073e9SAndroid Build Coastguard Worker from the default implementation. This uses graph_copy to copy the nodes 968*da0073e9SAndroid Build Coastguard Worker in an iterative way, rather than recursive. It also populates the 969*da0073e9SAndroid Build Coastguard Worker memoization table to prevent unnecessary copies (e.g. references to 970*da0073e9SAndroid Build Coastguard Worker nodes or other parts of the Graph from a custom GraphModule implementation. 
971*da0073e9SAndroid Build Coastguard Worker """ 972*da0073e9SAndroid Build Coastguard Worker memo = memo if memo else {} 973*da0073e9SAndroid Build Coastguard Worker g = Graph(tracer_cls=self._tracer_cls) 974*da0073e9SAndroid Build Coastguard Worker output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) 975*da0073e9SAndroid Build Coastguard Worker g._codegen = copy.deepcopy(self._codegen) 976*da0073e9SAndroid Build Coastguard Worker assert isinstance(output_vals, tuple) 977*da0073e9SAndroid Build Coastguard Worker output_val, old_output_node = output_vals 978*da0073e9SAndroid Build Coastguard Worker new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None)) 979*da0073e9SAndroid Build Coastguard Worker new_output_node.meta = copy.copy(old_output_node.meta) 980*da0073e9SAndroid Build Coastguard Worker return g 981*da0073e9SAndroid Build Coastguard Worker 982*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 983*da0073e9SAndroid Build Coastguard Worker def create_node(self, op: str, target: 'Target', 984*da0073e9SAndroid Build Coastguard Worker args: Optional[Tuple['Argument', ...]] = None, 985*da0073e9SAndroid Build Coastguard Worker kwargs: Optional[Dict[str, 'Argument']] = None, 986*da0073e9SAndroid Build Coastguard Worker name: Optional[str] = None, 987*da0073e9SAndroid Build Coastguard Worker type_expr: Optional[Any] = None) -> Node: 988*da0073e9SAndroid Build Coastguard Worker """ 989*da0073e9SAndroid Build Coastguard Worker Create a ``Node`` and add it to the ``Graph`` at the current insert-point. 990*da0073e9SAndroid Build Coastguard Worker Note that the current insert-point can be set via :meth:`Graph.inserting_before` 991*da0073e9SAndroid Build Coastguard Worker and :meth:`Graph.inserting_after`. 992*da0073e9SAndroid Build Coastguard Worker 993*da0073e9SAndroid Build Coastguard Worker Args: 994*da0073e9SAndroid Build Coastguard Worker op (str): the opcode for this Node. 
One of 'call_function', 'call_method', 'get_attr', 995*da0073e9SAndroid Build Coastguard Worker 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are 996*da0073e9SAndroid Build Coastguard Worker described in the ``Graph`` docstring. 997*da0073e9SAndroid Build Coastguard Worker 998*da0073e9SAndroid Build Coastguard Worker args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Worker kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node 1001*da0073e9SAndroid Build Coastguard Worker 1002*da0073e9SAndroid Build Coastguard Worker name (Optional[str]): an optional string name for the ``Node``. 1003*da0073e9SAndroid Build Coastguard Worker This will influence the name of the value assigned to in the 1004*da0073e9SAndroid Build Coastguard Worker Python generated code. 1005*da0073e9SAndroid Build Coastguard Worker 1006*da0073e9SAndroid Build Coastguard Worker type_expr (Optional[Any]): an optional type annotation representing the 1007*da0073e9SAndroid Build Coastguard Worker Python type the output of this node will have. 1008*da0073e9SAndroid Build Coastguard Worker 1009*da0073e9SAndroid Build Coastguard Worker Returns: 1010*da0073e9SAndroid Build Coastguard Worker 1011*da0073e9SAndroid Build Coastguard Worker The newly-created and inserted node. 
1012*da0073e9SAndroid Build Coastguard Worker """ 1013*da0073e9SAndroid Build Coastguard Worker assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') 1014*da0073e9SAndroid Build Coastguard Worker args = () if args is None else args 1015*da0073e9SAndroid Build Coastguard Worker kwargs = {} if kwargs is None else kwargs 1016*da0073e9SAndroid Build Coastguard Worker assert isinstance(args, tuple), "args must be a tuple" 1017*da0073e9SAndroid Build Coastguard Worker assert isinstance(kwargs, dict), "kwargs must be a dict" 1018*da0073e9SAndroid Build Coastguard Worker 1019*da0073e9SAndroid Build Coastguard Worker candidate = name if name is not None else self._target_to_str(target) 1020*da0073e9SAndroid Build Coastguard Worker name = self._graph_namespace.create_name(candidate, None) 1021*da0073e9SAndroid Build Coastguard Worker n = Node(self, name, op, target, args, kwargs, type_expr) 1022*da0073e9SAndroid Build Coastguard Worker 1023*da0073e9SAndroid Build Coastguard Worker if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None: 1024*da0073e9SAndroid Build Coastguard Worker for f in self.owning_module._create_node_hooks: 1025*da0073e9SAndroid Build Coastguard Worker f(n) 1026*da0073e9SAndroid Build Coastguard Worker 1027*da0073e9SAndroid Build Coastguard Worker self._graph_namespace.associate_name_with_obj(name, n) 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker self._insert(n) 1030*da0073e9SAndroid Build Coastguard Worker self._find_nodes_lookup_table.insert(n) 1031*da0073e9SAndroid Build Coastguard Worker self._len += 1 1032*da0073e9SAndroid Build Coastguard Worker return n 1033*da0073e9SAndroid Build Coastguard Worker 1034*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=False) 1035*da0073e9SAndroid Build Coastguard Worker def process_inputs(self, *args): 1036*da0073e9SAndroid Build Coastguard 
Worker """ 1037*da0073e9SAndroid Build Coastguard Worker Processes args so that they can be passed to the FX graph. 1038*da0073e9SAndroid Build Coastguard Worker """ 1039*da0073e9SAndroid Build Coastguard Worker return self._codegen.process_inputs(*args) 1040*da0073e9SAndroid Build Coastguard Worker 1041*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=False) 1042*da0073e9SAndroid Build Coastguard Worker def process_outputs(self, out): 1043*da0073e9SAndroid Build Coastguard Worker return self._codegen.process_outputs(out) 1044*da0073e9SAndroid Build Coastguard Worker 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1047*da0073e9SAndroid Build Coastguard Worker def erase_node(self, to_erase : Node) -> None: 1048*da0073e9SAndroid Build Coastguard Worker """ 1049*da0073e9SAndroid Build Coastguard Worker Erases a ``Node`` from the ``Graph``. Throws an exception if 1050*da0073e9SAndroid Build Coastguard Worker there are still users of that node in the ``Graph``. 1051*da0073e9SAndroid Build Coastguard Worker 1052*da0073e9SAndroid Build Coastguard Worker Args: 1053*da0073e9SAndroid Build Coastguard Worker 1054*da0073e9SAndroid Build Coastguard Worker to_erase (Node): The ``Node`` to erase from the ``Graph``. 
1055*da0073e9SAndroid Build Coastguard Worker """ 1056*da0073e9SAndroid Build Coastguard Worker if len(to_erase.users) > 0: 1057*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' 1058*da0073e9SAndroid Build Coastguard Worker f'users in the graph: {to_erase.users}!') 1059*da0073e9SAndroid Build Coastguard Worker if to_erase.graph != self: 1060*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") 1061*da0073e9SAndroid Build Coastguard Worker if to_erase._erased: 1062*da0073e9SAndroid Build Coastguard Worker warnings.warn(f"erase_node({to_erase}) on an already erased node") 1063*da0073e9SAndroid Build Coastguard Worker return 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None: 1066*da0073e9SAndroid Build Coastguard Worker for f in self.owning_module._erase_node_hooks: 1067*da0073e9SAndroid Build Coastguard Worker f(to_erase) 1068*da0073e9SAndroid Build Coastguard Worker 1069*da0073e9SAndroid Build Coastguard Worker self._find_nodes_lookup_table.remove(to_erase) 1070*da0073e9SAndroid Build Coastguard Worker to_erase._remove_from_list() 1071*da0073e9SAndroid Build Coastguard Worker to_erase._erased = True # iterators may retain handles to erased nodes 1072*da0073e9SAndroid Build Coastguard Worker self._len -= 1 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker # Null out this Node's argument nodes so that the Nodes referred to 1075*da0073e9SAndroid Build Coastguard Worker # can update their ``users`` accordingly 1076*da0073e9SAndroid Build Coastguard Worker new_args = map_arg(to_erase.args, lambda n: None) 1077*da0073e9SAndroid Build Coastguard Worker assert isinstance(new_args, tuple) 1078*da0073e9SAndroid Build Coastguard Worker 
to_erase.args = new_args 1079*da0073e9SAndroid Build Coastguard Worker new_kwargs = map_arg(to_erase.kwargs, lambda n: None) 1080*da0073e9SAndroid Build Coastguard Worker assert isinstance(new_kwargs, dict) 1081*da0073e9SAndroid Build Coastguard Worker to_erase.kwargs = new_kwargs 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1084*da0073e9SAndroid Build Coastguard Worker def inserting_before(self, n: Optional[Node] = None): 1085*da0073e9SAndroid Build Coastguard Worker """Set the point at which create_node and companion methods will insert into the graph. 1086*da0073e9SAndroid Build Coastguard Worker When used within a 'with' statement, this will temporary set the insert point and 1087*da0073e9SAndroid Build Coastguard Worker then restore it when the with statement exits:: 1088*da0073e9SAndroid Build Coastguard Worker 1089*da0073e9SAndroid Build Coastguard Worker with g.inserting_before(n): 1090*da0073e9SAndroid Build Coastguard Worker ... # inserting before node n 1091*da0073e9SAndroid Build Coastguard Worker ... # insert point restored to what it was previously 1092*da0073e9SAndroid Build Coastguard Worker g.inserting_before(n) # set the insert point permanently 1093*da0073e9SAndroid Build Coastguard Worker 1094*da0073e9SAndroid Build Coastguard Worker Args: 1095*da0073e9SAndroid Build Coastguard Worker 1096*da0073e9SAndroid Build Coastguard Worker n (Optional[Node]): The node before which to insert. If None this will insert before 1097*da0073e9SAndroid Build Coastguard Worker the beginning of the entire graph. 1098*da0073e9SAndroid Build Coastguard Worker 1099*da0073e9SAndroid Build Coastguard Worker Returns: 1100*da0073e9SAndroid Build Coastguard Worker A resource manager that will restore the insert point on ``__exit__``. 
1101*da0073e9SAndroid Build Coastguard Worker """ 1102*da0073e9SAndroid Build Coastguard Worker if n is None: 1103*da0073e9SAndroid Build Coastguard Worker return self.inserting_after(self._root) 1104*da0073e9SAndroid Build Coastguard Worker assert n.graph == self, "Node to insert before is not in graph." 1105*da0073e9SAndroid Build Coastguard Worker return _InsertPoint(self, n.prepend) 1106*da0073e9SAndroid Build Coastguard Worker 1107*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1108*da0073e9SAndroid Build Coastguard Worker def inserting_after(self, n: Optional[Node] = None): 1109*da0073e9SAndroid Build Coastguard Worker """Set the point at which create_node and companion methods will insert into the graph. 1110*da0073e9SAndroid Build Coastguard Worker When used within a 'with' statement, this will temporary set the insert point and 1111*da0073e9SAndroid Build Coastguard Worker then restore it when the with statement exits:: 1112*da0073e9SAndroid Build Coastguard Worker 1113*da0073e9SAndroid Build Coastguard Worker with g.inserting_after(n): 1114*da0073e9SAndroid Build Coastguard Worker ... # inserting after node n 1115*da0073e9SAndroid Build Coastguard Worker ... # insert point restored to what it was previously 1116*da0073e9SAndroid Build Coastguard Worker g.inserting_after(n) # set the insert point permanently 1117*da0073e9SAndroid Build Coastguard Worker 1118*da0073e9SAndroid Build Coastguard Worker Args: 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker n (Optional[Node]): The node before which to insert. If None this will insert after 1121*da0073e9SAndroid Build Coastguard Worker the beginning of the entire graph. 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker Returns: 1124*da0073e9SAndroid Build Coastguard Worker A resource manager that will restore the insert point on ``__exit__``. 
1125*da0073e9SAndroid Build Coastguard Worker """ 1126*da0073e9SAndroid Build Coastguard Worker if n is None: 1127*da0073e9SAndroid Build Coastguard Worker return self.inserting_before(self._root) 1128*da0073e9SAndroid Build Coastguard Worker assert n.graph == self, "Node to insert after is not in graph." 1129*da0073e9SAndroid Build Coastguard Worker return _InsertPoint(self, n.append) 1130*da0073e9SAndroid Build Coastguard Worker 1131*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1132*da0073e9SAndroid Build Coastguard Worker def placeholder(self, name: str, type_expr: Optional[Any] = None, 1133*da0073e9SAndroid Build Coastguard Worker default_value : Any = inspect.Signature.empty) -> Node: 1134*da0073e9SAndroid Build Coastguard Worker """ 1135*da0073e9SAndroid Build Coastguard Worker Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents 1136*da0073e9SAndroid Build Coastguard Worker a function input. 1137*da0073e9SAndroid Build Coastguard Worker 1138*da0073e9SAndroid Build Coastguard Worker Args: 1139*da0073e9SAndroid Build Coastguard Worker 1140*da0073e9SAndroid Build Coastguard Worker name (str): A name for the input value. This corresponds to the name 1141*da0073e9SAndroid Build Coastguard Worker of the positional argument to the function this ``Graph`` represents. 1142*da0073e9SAndroid Build Coastguard Worker 1143*da0073e9SAndroid Build Coastguard Worker type_expr (Optional[Any]): an optional type annotation representing the 1144*da0073e9SAndroid Build Coastguard Worker Python type the output of this node will have. This is needed in some 1145*da0073e9SAndroid Build Coastguard Worker cases for proper code generation (e.g. when the function is used 1146*da0073e9SAndroid Build Coastguard Worker subsequently in TorchScript compilation). 
1147*da0073e9SAndroid Build Coastguard Worker 1148*da0073e9SAndroid Build Coastguard Worker default_value (Any): The default value this function argument should take 1149*da0073e9SAndroid Build Coastguard Worker on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` 1150*da0073e9SAndroid Build Coastguard Worker should be passed as this argument to specify that the parameter does _not_ 1151*da0073e9SAndroid Build Coastguard Worker have a default value. 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker .. note:: 1154*da0073e9SAndroid Build Coastguard Worker The same insertion point and type expression rules apply for this method 1155*da0073e9SAndroid Build Coastguard Worker as ``Graph.create_node``. 1156*da0073e9SAndroid Build Coastguard Worker """ 1157*da0073e9SAndroid Build Coastguard Worker args = () if default_value is inspect.Signature.empty else (default_value,) 1158*da0073e9SAndroid Build Coastguard Worker return self.create_node('placeholder', name, args=args, type_expr=type_expr) 1159*da0073e9SAndroid Build Coastguard Worker 1160*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1161*da0073e9SAndroid Build Coastguard Worker def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: 1162*da0073e9SAndroid Build Coastguard Worker """ 1163*da0073e9SAndroid Build Coastguard Worker Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the 1164*da0073e9SAndroid Build Coastguard Worker fetch of an attribute from the ``Module`` hierarchy. 1165*da0073e9SAndroid Build Coastguard Worker 1166*da0073e9SAndroid Build Coastguard Worker Args: 1167*da0073e9SAndroid Build Coastguard Worker 1168*da0073e9SAndroid Build Coastguard Worker qualified_name (str): the fully-qualified name of the attribute to be retrieved. 
1169*da0073e9SAndroid Build Coastguard Worker For example, if the traced Module has a submodule named ``foo``, which has a 1170*da0073e9SAndroid Build Coastguard Worker submodule named ``bar``, which has an attribute named ``baz``, the qualified 1171*da0073e9SAndroid Build Coastguard Worker name ``foo.bar.baz`` should be passed as ``qualified_name``. 1172*da0073e9SAndroid Build Coastguard Worker 1173*da0073e9SAndroid Build Coastguard Worker type_expr (Optional[Any]): an optional type annotation representing the 1174*da0073e9SAndroid Build Coastguard Worker Python type the output of this node will have. 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker 1177*da0073e9SAndroid Build Coastguard Worker Returns: 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build Coastguard Worker The newly-created and inserted ``get_attr`` node. 1180*da0073e9SAndroid Build Coastguard Worker 1181*da0073e9SAndroid Build Coastguard Worker .. note:: 1182*da0073e9SAndroid Build Coastguard Worker The same insertion point and type expression rules apply for this method 1183*da0073e9SAndroid Build Coastguard Worker as ``Graph.create_node``. 
1184*da0073e9SAndroid Build Coastguard Worker """ 1185*da0073e9SAndroid Build Coastguard Worker def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: 1186*da0073e9SAndroid Build Coastguard Worker module_path, _, name = qualified_name.rpartition(".") 1187*da0073e9SAndroid Build Coastguard Worker 1188*da0073e9SAndroid Build Coastguard Worker try: 1189*da0073e9SAndroid Build Coastguard Worker submod: torch.nn.Module = mod.get_submodule(module_path) 1190*da0073e9SAndroid Build Coastguard Worker except AttributeError: 1191*da0073e9SAndroid Build Coastguard Worker warnings.warn(f"Failed to fetch module {module_path}!") 1192*da0073e9SAndroid Build Coastguard Worker return False 1193*da0073e9SAndroid Build Coastguard Worker 1194*da0073e9SAndroid Build Coastguard Worker if not hasattr(submod, name): 1195*da0073e9SAndroid Build Coastguard Worker return False 1196*da0073e9SAndroid Build Coastguard Worker 1197*da0073e9SAndroid Build Coastguard Worker res = getattr(submod, name) 1198*da0073e9SAndroid Build Coastguard Worker 1199*da0073e9SAndroid Build Coastguard Worker if (not isinstance(res, torch.nn.Module) 1200*da0073e9SAndroid Build Coastguard Worker and not isinstance(res, torch.nn.Parameter) 1201*da0073e9SAndroid Build Coastguard Worker and name not in submod._buffers): 1202*da0073e9SAndroid Build Coastguard Worker return False 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Worker return True 1205*da0073e9SAndroid Build Coastguard Worker 1206*da0073e9SAndroid Build Coastguard Worker if (self.owning_module and 1207*da0073e9SAndroid Build Coastguard Worker not _get_attr_reference_exists(self.owning_module, qualified_name)): 1208*da0073e9SAndroid Build Coastguard Worker warnings.warn("Attempted to insert a get_attr Node with no " 1209*da0073e9SAndroid Build Coastguard Worker "underlying reference in the owning " 1210*da0073e9SAndroid Build Coastguard Worker "GraphModule! 
Call " 1211*da0073e9SAndroid Build Coastguard Worker "GraphModule.add_submodule to add the " 1212*da0073e9SAndroid Build Coastguard Worker "necessary submodule, " 1213*da0073e9SAndroid Build Coastguard Worker "GraphModule.add_parameter to add the " 1214*da0073e9SAndroid Build Coastguard Worker "necessary Parameter, or " 1215*da0073e9SAndroid Build Coastguard Worker "nn.Module.register_buffer to add the " 1216*da0073e9SAndroid Build Coastguard Worker "necessary buffer", stacklevel=2) 1217*da0073e9SAndroid Build Coastguard Worker return self.create_node('get_attr', qualified_name, type_expr=type_expr) 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1220*da0073e9SAndroid Build Coastguard Worker def call_module(self, 1221*da0073e9SAndroid Build Coastguard Worker module_name: str, 1222*da0073e9SAndroid Build Coastguard Worker args: Optional[Tuple['Argument', ...]] = None, 1223*da0073e9SAndroid Build Coastguard Worker kwargs: Optional[Dict[str, 'Argument']] = None, 1224*da0073e9SAndroid Build Coastguard Worker type_expr: Optional[Any] = None) -> Node: 1225*da0073e9SAndroid Build Coastguard Worker """ 1226*da0073e9SAndroid Build Coastguard Worker Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node 1227*da0073e9SAndroid Build Coastguard Worker represents a call to the forward() function of a ``Module`` in the ``Module`` 1228*da0073e9SAndroid Build Coastguard Worker hierarchy. 1229*da0073e9SAndroid Build Coastguard Worker 1230*da0073e9SAndroid Build Coastguard Worker Args: 1231*da0073e9SAndroid Build Coastguard Worker 1232*da0073e9SAndroid Build Coastguard Worker module_name (str): The qualified name of the ``Module`` in the ``Module`` 1233*da0073e9SAndroid Build Coastguard Worker hierarchy to be called. 
For example, if the traced ``Module`` has a 1234*da0073e9SAndroid Build Coastguard Worker submodule named ``foo``, which has a submodule named ``bar``, the 1235*da0073e9SAndroid Build Coastguard Worker qualified name ``foo.bar`` should be passed as ``module_name`` to 1236*da0073e9SAndroid Build Coastguard Worker call that module. 1237*da0073e9SAndroid Build Coastguard Worker 1238*da0073e9SAndroid Build Coastguard Worker args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed 1239*da0073e9SAndroid Build Coastguard Worker to the called method. Note that this should *not* include a ``self`` argument. 1240*da0073e9SAndroid Build Coastguard Worker 1241*da0073e9SAndroid Build Coastguard Worker kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed 1242*da0073e9SAndroid Build Coastguard Worker to the called method 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker type_expr (Optional[Any]): an optional type annotation representing the 1245*da0073e9SAndroid Build Coastguard Worker Python type the output of this node will have. 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker Returns: 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker The newly-created and inserted ``call_module`` node. 1250*da0073e9SAndroid Build Coastguard Worker 1251*da0073e9SAndroid Build Coastguard Worker .. note:: 1252*da0073e9SAndroid Build Coastguard Worker The same insertion point and type expression rules apply for this method 1253*da0073e9SAndroid Build Coastguard Worker as :meth:`Graph.create_node`. 
1254*da0073e9SAndroid Build Coastguard Worker """ 1255*da0073e9SAndroid Build Coastguard Worker if (self.owning_module and 1256*da0073e9SAndroid Build Coastguard Worker self.owning_module.get_submodule(module_name) is None): 1257*da0073e9SAndroid Build Coastguard Worker warnings.warn("Attempted to insert a call_module Node with " 1258*da0073e9SAndroid Build Coastguard Worker "no underlying reference in the owning " 1259*da0073e9SAndroid Build Coastguard Worker "GraphModule! Call " 1260*da0073e9SAndroid Build Coastguard Worker "GraphModule.add_submodule to add the " 1261*da0073e9SAndroid Build Coastguard Worker "necessary submodule") 1262*da0073e9SAndroid Build Coastguard Worker return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1265*da0073e9SAndroid Build Coastguard Worker def call_method(self, 1266*da0073e9SAndroid Build Coastguard Worker method_name: str, 1267*da0073e9SAndroid Build Coastguard Worker args: Optional[Tuple['Argument', ...]] = None, 1268*da0073e9SAndroid Build Coastguard Worker kwargs: Optional[Dict[str, 'Argument']] = None, 1269*da0073e9SAndroid Build Coastguard Worker type_expr: Optional[Any] = None) -> Node: 1270*da0073e9SAndroid Build Coastguard Worker """ 1271*da0073e9SAndroid Build Coastguard Worker Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node 1272*da0073e9SAndroid Build Coastguard Worker represents a call to a given method on the 0th element of ``args``. 1273*da0073e9SAndroid Build Coastguard Worker 1274*da0073e9SAndroid Build Coastguard Worker Args: 1275*da0073e9SAndroid Build Coastguard Worker 1276*da0073e9SAndroid Build Coastguard Worker method_name (str): The name of the method to apply to the self argument. 
1277*da0073e9SAndroid Build Coastguard Worker For example, if args[0] is a ``Node`` representing a ``Tensor``, 1278*da0073e9SAndroid Build Coastguard Worker then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. 1279*da0073e9SAndroid Build Coastguard Worker 1280*da0073e9SAndroid Build Coastguard Worker args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed 1281*da0073e9SAndroid Build Coastguard Worker to the called method. Note that this *should* include a ``self`` argument. 1282*da0073e9SAndroid Build Coastguard Worker 1283*da0073e9SAndroid Build Coastguard Worker kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed 1284*da0073e9SAndroid Build Coastguard Worker to the called method 1285*da0073e9SAndroid Build Coastguard Worker 1286*da0073e9SAndroid Build Coastguard Worker type_expr (Optional[Any]): an optional type annotation representing the 1287*da0073e9SAndroid Build Coastguard Worker Python type the output of this node will have. 1288*da0073e9SAndroid Build Coastguard Worker 1289*da0073e9SAndroid Build Coastguard Worker Returns: 1290*da0073e9SAndroid Build Coastguard Worker 1291*da0073e9SAndroid Build Coastguard Worker The newly created and inserted ``call_method`` node. 1292*da0073e9SAndroid Build Coastguard Worker 1293*da0073e9SAndroid Build Coastguard Worker .. note:: 1294*da0073e9SAndroid Build Coastguard Worker The same insertion point and type expression rules apply for this method 1295*da0073e9SAndroid Build Coastguard Worker as :meth:`Graph.create_node`. 
1296*da0073e9SAndroid Build Coastguard Worker """ 1297*da0073e9SAndroid Build Coastguard Worker return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) 1298*da0073e9SAndroid Build Coastguard Worker 1299*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1300*da0073e9SAndroid Build Coastguard Worker def call_function(self, 1301*da0073e9SAndroid Build Coastguard Worker the_function: Callable[..., Any], 1302*da0073e9SAndroid Build Coastguard Worker args: Optional[Tuple['Argument', ...]] = None, 1303*da0073e9SAndroid Build Coastguard Worker kwargs: Optional[Dict[str, 'Argument']] = None, 1304*da0073e9SAndroid Build Coastguard Worker type_expr: Optional[Any] = None) -> Node: 1305*da0073e9SAndroid Build Coastguard Worker """ 1306*da0073e9SAndroid Build Coastguard Worker Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node 1307*da0073e9SAndroid Build Coastguard Worker represents a call to a Python callable, specified by ``the_function``. 1308*da0073e9SAndroid Build Coastguard Worker 1309*da0073e9SAndroid Build Coastguard Worker Args: 1310*da0073e9SAndroid Build Coastguard Worker 1311*da0073e9SAndroid Build Coastguard Worker the_function (Callable[..., Any]): The function to be called. Can be any PyTorch 1312*da0073e9SAndroid Build Coastguard Worker operator, Python function, or member of the ``builtins`` or ``operator`` 1313*da0073e9SAndroid Build Coastguard Worker namespaces. 1314*da0073e9SAndroid Build Coastguard Worker 1315*da0073e9SAndroid Build Coastguard Worker args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed 1316*da0073e9SAndroid Build Coastguard Worker to the called function. 
1317*da0073e9SAndroid Build Coastguard Worker 1318*da0073e9SAndroid Build Coastguard Worker kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed 1319*da0073e9SAndroid Build Coastguard Worker to the called function 1320*da0073e9SAndroid Build Coastguard Worker 1321*da0073e9SAndroid Build Coastguard Worker type_expr (Optional[Any]): an optional type annotation representing the 1322*da0073e9SAndroid Build Coastguard Worker Python type the output of this node will have. 1323*da0073e9SAndroid Build Coastguard Worker 1324*da0073e9SAndroid Build Coastguard Worker Returns: 1325*da0073e9SAndroid Build Coastguard Worker 1326*da0073e9SAndroid Build Coastguard Worker The newly created and inserted ``call_function`` node. 1327*da0073e9SAndroid Build Coastguard Worker 1328*da0073e9SAndroid Build Coastguard Worker .. note:: 1329*da0073e9SAndroid Build Coastguard Worker The same insertion point and type expression rules apply for this method 1330*da0073e9SAndroid Build Coastguard Worker as :meth:`Graph.create_node`. 1331*da0073e9SAndroid Build Coastguard Worker """ 1332*da0073e9SAndroid Build Coastguard Worker return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1335*da0073e9SAndroid Build Coastguard Worker def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: 1336*da0073e9SAndroid Build Coastguard Worker """ 1337*da0073e9SAndroid Build Coastguard Worker Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from 1338*da0073e9SAndroid Build Coastguard Worker the graph of node to the graph of self. 
Example:: 1339*da0073e9SAndroid Build Coastguard Worker 1340*da0073e9SAndroid Build Coastguard Worker # Copying all the nodes in `g` into `new_graph` 1341*da0073e9SAndroid Build Coastguard Worker g : torch.fx.Graph = ... 1342*da0073e9SAndroid Build Coastguard Worker new_graph = torch.fx.graph() 1343*da0073e9SAndroid Build Coastguard Worker value_remap = {} 1344*da0073e9SAndroid Build Coastguard Worker for node in g.nodes: 1345*da0073e9SAndroid Build Coastguard Worker value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) 1346*da0073e9SAndroid Build Coastguard Worker 1347*da0073e9SAndroid Build Coastguard Worker Args: 1348*da0073e9SAndroid Build Coastguard Worker 1349*da0073e9SAndroid Build Coastguard Worker node (Node): The node to copy into ``self``. 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard Worker arg_transform (Callable[[Node], Argument]): A function that transforms 1352*da0073e9SAndroid Build Coastguard Worker ``Node`` arguments in node's ``args`` and ``kwargs`` into the 1353*da0073e9SAndroid Build Coastguard Worker equivalent argument in ``self``. In the simplest case, this should 1354*da0073e9SAndroid Build Coastguard Worker retrieve a value out of a table mapping Nodes in the original 1355*da0073e9SAndroid Build Coastguard Worker graph to ``self``. 
1356*da0073e9SAndroid Build Coastguard Worker """ 1357*da0073e9SAndroid Build Coastguard Worker args = map_arg(node.args, arg_transform) 1358*da0073e9SAndroid Build Coastguard Worker kwargs = map_arg(node.kwargs, arg_transform) 1359*da0073e9SAndroid Build Coastguard Worker assert isinstance(args, tuple) 1360*da0073e9SAndroid Build Coastguard Worker assert isinstance(kwargs, dict) 1361*da0073e9SAndroid Build Coastguard Worker result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) 1362*da0073e9SAndroid Build Coastguard Worker result_node.meta = copy.copy(node.meta) 1363*da0073e9SAndroid Build Coastguard Worker return result_node 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1366*da0073e9SAndroid Build Coastguard Worker def output(self, result: 'Argument', type_expr: Optional[Any] = None): 1367*da0073e9SAndroid Build Coastguard Worker """ 1368*da0073e9SAndroid Build Coastguard Worker Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents 1369*da0073e9SAndroid Build Coastguard Worker a ``return`` statement in Python code. ``result`` is the value that should 1370*da0073e9SAndroid Build Coastguard Worker be returned. 1371*da0073e9SAndroid Build Coastguard Worker 1372*da0073e9SAndroid Build Coastguard Worker Args: 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker result (Argument): The value to be returned. 1375*da0073e9SAndroid Build Coastguard Worker 1376*da0073e9SAndroid Build Coastguard Worker type_expr (Optional[Any]): an optional type annotation representing the 1377*da0073e9SAndroid Build Coastguard Worker Python type the output of this node will have. 1378*da0073e9SAndroid Build Coastguard Worker 1379*da0073e9SAndroid Build Coastguard Worker .. 
note:: 1380*da0073e9SAndroid Build Coastguard Worker 1381*da0073e9SAndroid Build Coastguard Worker The same insertion point and type expression rules apply for this method 1382*da0073e9SAndroid Build Coastguard Worker as ``Graph.create_node``. 1383*da0073e9SAndroid Build Coastguard Worker """ 1384*da0073e9SAndroid Build Coastguard Worker return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) 1385*da0073e9SAndroid Build Coastguard Worker 1386*da0073e9SAndroid Build Coastguard Worker def _target_to_str(self, target : Target) -> str: 1387*da0073e9SAndroid Build Coastguard Worker if callable(target): 1388*da0073e9SAndroid Build Coastguard Worker op = target.__name__ 1389*da0073e9SAndroid Build Coastguard Worker else: 1390*da0073e9SAndroid Build Coastguard Worker assert isinstance(target, str) 1391*da0073e9SAndroid Build Coastguard Worker op = target 1392*da0073e9SAndroid Build Coastguard Worker if _is_magic(op): 1393*da0073e9SAndroid Build Coastguard Worker op = op[2:-2] 1394*da0073e9SAndroid Build Coastguard Worker op = _snake_case(op) 1395*da0073e9SAndroid Build Coastguard Worker return op 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1398*da0073e9SAndroid Build Coastguard Worker def python_code( 1399*da0073e9SAndroid Build Coastguard Worker self, root_module: str, *, 1400*da0073e9SAndroid Build Coastguard Worker verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False 1401*da0073e9SAndroid Build Coastguard Worker ) -> PythonCode: 1402*da0073e9SAndroid Build Coastguard Worker """ 1403*da0073e9SAndroid Build Coastguard Worker Turn this ``Graph`` into valid Python code. 
1404*da0073e9SAndroid Build Coastguard Worker 1405*da0073e9SAndroid Build Coastguard Worker Args: 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker root_module (str): The name of the root module on which to look-up 1408*da0073e9SAndroid Build Coastguard Worker qualified name targets. This is usually 'self'. 1409*da0073e9SAndroid Build Coastguard Worker 1410*da0073e9SAndroid Build Coastguard Worker Returns: 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker A PythonCode object, consisting of two fields: 1413*da0073e9SAndroid Build Coastguard Worker src: the Python source code representing the object 1414*da0073e9SAndroid Build Coastguard Worker globals: a dictionary of global names in `src` -> the objects that they reference. 1415*da0073e9SAndroid Build Coastguard Worker """ 1416*da0073e9SAndroid Build Coastguard Worker # NOTE: [Graph Namespaces] 1417*da0073e9SAndroid Build Coastguard Worker # 1418*da0073e9SAndroid Build Coastguard Worker # There are two types of symbols in generated Python source code: 1419*da0073e9SAndroid Build Coastguard Worker # locals and globals. 1420*da0073e9SAndroid Build Coastguard Worker # Locals are locally defined by the output of a node in the Graph. 1421*da0073e9SAndroid Build Coastguard Worker # Globals are references to external objects, like functions or types. 1422*da0073e9SAndroid Build Coastguard Worker # 1423*da0073e9SAndroid Build Coastguard Worker # When generating Python code, we need to make sure to name things 1424*da0073e9SAndroid Build Coastguard Worker # appropriately. In particular: 1425*da0073e9SAndroid Build Coastguard Worker # - All names should be unique, to avoid weird shadowing bugs. 1426*da0073e9SAndroid Build Coastguard Worker # - These names need to be consistent, e.g. a object should always be 1427*da0073e9SAndroid Build Coastguard Worker # referenced by the same name. 
1428*da0073e9SAndroid Build Coastguard Worker # 1429*da0073e9SAndroid Build Coastguard Worker # To do this, we create a new namespace just for this source. All names 1430*da0073e9SAndroid Build Coastguard Worker # that get printed must come from this namespace. 1431*da0073e9SAndroid Build Coastguard Worker # 1432*da0073e9SAndroid Build Coastguard Worker # Why can't we re-use node.name? Because it was generated within the 1433*da0073e9SAndroid Build Coastguard Worker # namespace `self._graph_namespace`. In order to provide uniqueness 1434*da0073e9SAndroid Build Coastguard Worker # over both locals (node.name) *and* globals, we create a completely 1435*da0073e9SAndroid Build Coastguard Worker # new namespace to put all identifiers in. 1436*da0073e9SAndroid Build Coastguard Worker namespace = _Namespace() 1437*da0073e9SAndroid Build Coastguard Worker 1438*da0073e9SAndroid Build Coastguard Worker # Override Node's repr to generate a valid name within our namespace. 1439*da0073e9SAndroid Build Coastguard Worker # Since repr() is designed to produce a valid Python expression, it 1440*da0073e9SAndroid Build Coastguard Worker # makes sense to re-use it. This way, it's easy to print something like 1441*da0073e9SAndroid Build Coastguard Worker # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is 1442*da0073e9SAndroid Build Coastguard Worker # implemented cooperatively to allow this. 
1443*da0073e9SAndroid Build Coastguard Worker def node_repr(n: Node): 1444*da0073e9SAndroid Build Coastguard Worker return namespace.create_name(n.name, n) 1445*da0073e9SAndroid Build Coastguard Worker 1446*da0073e9SAndroid Build Coastguard Worker @contextmanager 1447*da0073e9SAndroid Build Coastguard Worker def override_node_repr(graph: Graph): 1448*da0073e9SAndroid Build Coastguard Worker orig_repr_fns = {} 1449*da0073e9SAndroid Build Coastguard Worker for node in graph.nodes: 1450*da0073e9SAndroid Build Coastguard Worker orig_repr_fns[node] = node._repr_fn 1451*da0073e9SAndroid Build Coastguard Worker node._repr_fn = node_repr 1452*da0073e9SAndroid Build Coastguard Worker try: 1453*da0073e9SAndroid Build Coastguard Worker yield None 1454*da0073e9SAndroid Build Coastguard Worker finally: 1455*da0073e9SAndroid Build Coastguard Worker # restore the original repr functions 1456*da0073e9SAndroid Build Coastguard Worker for node in graph.nodes: 1457*da0073e9SAndroid Build Coastguard Worker node._repr_fn = orig_repr_fns[node] 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Worker with override_node_repr(self): 1460*da0073e9SAndroid Build Coastguard Worker return self._python_code( 1461*da0073e9SAndroid Build Coastguard Worker root_module, namespace, 1462*da0073e9SAndroid Build Coastguard Worker verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored 1463*da0073e9SAndroid Build Coastguard Worker ) 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker def _python_code( 1466*da0073e9SAndroid Build Coastguard Worker self, root_module: str, namespace: _Namespace, *, 1467*da0073e9SAndroid Build Coastguard Worker verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, 1468*da0073e9SAndroid Build Coastguard Worker ) -> PythonCode: 1469*da0073e9SAndroid Build Coastguard Worker return self._codegen._gen_python_code( 
1470*da0073e9SAndroid Build Coastguard Worker self.nodes, root_module, namespace, 1471*da0073e9SAndroid Build Coastguard Worker verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored 1472*da0073e9SAndroid Build Coastguard Worker ) 1473*da0073e9SAndroid Build Coastguard Worker 1474*da0073e9SAndroid Build Coastguard Worker 1475*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1476*da0073e9SAndroid Build Coastguard Worker """ 1477*da0073e9SAndroid Build Coastguard Worker Return a human-readable (not machine-readable) string representation 1478*da0073e9SAndroid Build Coastguard Worker of this Graph 1479*da0073e9SAndroid Build Coastguard Worker """ 1480*da0073e9SAndroid Build Coastguard Worker placeholder_names : List[str] = [] 1481*da0073e9SAndroid Build Coastguard Worker # This is a one-element array just so ``format_node`` can modify the closed 1482*da0073e9SAndroid Build Coastguard Worker # over value 1483*da0073e9SAndroid Build Coastguard Worker maybe_return_typename : List[str] = [''] 1484*da0073e9SAndroid Build Coastguard Worker 1485*da0073e9SAndroid Build Coastguard Worker node_strs = [node.format_node(placeholder_names) for node in self.nodes] 1486*da0073e9SAndroid Build Coastguard Worker param_str = ', '.join(placeholder_names) 1487*da0073e9SAndroid Build Coastguard Worker s = f'graph({param_str}){maybe_return_typename[0]}:' 1488*da0073e9SAndroid Build Coastguard Worker for node_str in node_strs: 1489*da0073e9SAndroid Build Coastguard Worker if node_str: 1490*da0073e9SAndroid Build Coastguard Worker s += '\n ' + node_str 1491*da0073e9SAndroid Build Coastguard Worker return s 1492*da0073e9SAndroid Build Coastguard Worker 1493*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1494*da0073e9SAndroid Build Coastguard Worker def print_tabular(self): 1495*da0073e9SAndroid Build Coastguard Worker """ 1496*da0073e9SAndroid Build Coastguard Worker Prints the intermediate 
representation of the graph in tabular 1497*da0073e9SAndroid Build Coastguard Worker format. Note that this API requires the ``tabulate`` module to be 1498*da0073e9SAndroid Build Coastguard Worker installed. 1499*da0073e9SAndroid Build Coastguard Worker """ 1500*da0073e9SAndroid Build Coastguard Worker try: 1501*da0073e9SAndroid Build Coastguard Worker from tabulate import tabulate 1502*da0073e9SAndroid Build Coastguard Worker except ImportError: 1503*da0073e9SAndroid Build Coastguard Worker print("`print_tabular` relies on the library `tabulate`, " 1504*da0073e9SAndroid Build Coastguard Worker "which could not be found on this machine. Run `pip " 1505*da0073e9SAndroid Build Coastguard Worker "install tabulate` to install the library.") 1506*da0073e9SAndroid Build Coastguard Worker raise 1507*da0073e9SAndroid Build Coastguard Worker 1508*da0073e9SAndroid Build Coastguard Worker node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] 1509*da0073e9SAndroid Build Coastguard Worker for n in self.nodes] 1510*da0073e9SAndroid Build Coastguard Worker print(tabulate(node_specs, 1511*da0073e9SAndroid Build Coastguard Worker headers=['opcode', 'name', 'target', 'args', 'kwargs'])) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1514*da0073e9SAndroid Build Coastguard Worker def lint(self): 1515*da0073e9SAndroid Build Coastguard Worker """ 1516*da0073e9SAndroid Build Coastguard Worker Runs various checks on this Graph to make sure it is well-formed. 
In 1517*da0073e9SAndroid Build Coastguard Worker particular: 1518*da0073e9SAndroid Build Coastguard Worker - Checks Nodes have correct ownership (owned by this graph) 1519*da0073e9SAndroid Build Coastguard Worker - Checks Nodes appear in topological order 1520*da0073e9SAndroid Build Coastguard Worker - If this Graph has an owning GraphModule, checks that targets 1521*da0073e9SAndroid Build Coastguard Worker exist in that GraphModule 1522*da0073e9SAndroid Build Coastguard Worker """ 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker # Check topo order 1525*da0073e9SAndroid Build Coastguard Worker def check_arg(arg : Node, n : Optional[Node] = None) -> None: 1526*da0073e9SAndroid Build Coastguard Worker context_str = f' of Node \'{n}\' ' if n else ' ' 1527*da0073e9SAndroid Build Coastguard Worker if arg.graph is not self: 1528*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' 1529*da0073e9SAndroid Build Coastguard Worker f'but was used as an argument! If you are copying nodes from another graph, make ' 1530*da0073e9SAndroid Build Coastguard Worker f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') 1531*da0073e9SAndroid Build Coastguard Worker if arg not in seen_values: 1532*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' 1533*da0073e9SAndroid Build Coastguard Worker f'defined! 
Please check that Nodes in the graph are topologically ordered\n{self}') 1534*da0073e9SAndroid Build Coastguard Worker 1535*da0073e9SAndroid Build Coastguard Worker seen_names : Set[str] = set() 1536*da0073e9SAndroid Build Coastguard Worker seen_values : Set[Node] = set() 1537*da0073e9SAndroid Build Coastguard Worker for node in self.nodes: 1538*da0073e9SAndroid Build Coastguard Worker if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: 1539*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') 1540*da0073e9SAndroid Build Coastguard Worker if node.graph is not self: 1541*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') 1542*da0073e9SAndroid Build Coastguard Worker if node not in self._find_nodes_lookup_table: 1543*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Node '{node}' is not added to the side table") 1544*da0073e9SAndroid Build Coastguard Worker map_arg(node.args, lambda arg: check_arg(arg, node)) 1545*da0073e9SAndroid Build Coastguard Worker map_arg(node.kwargs, lambda arg: check_arg(arg, node)) 1546*da0073e9SAndroid Build Coastguard Worker seen_values.add(node) 1547*da0073e9SAndroid Build Coastguard Worker 1548*da0073e9SAndroid Build Coastguard Worker if node.name in seen_names: 1549*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Node redefined name {node.name}!') 1550*da0073e9SAndroid Build Coastguard Worker seen_names.add(node.name) 1551*da0073e9SAndroid Build Coastguard Worker 1552*da0073e9SAndroid Build Coastguard Worker # Check targets are legit 1553*da0073e9SAndroid Build Coastguard Worker if self.owning_module: 1554*da0073e9SAndroid Build Coastguard Worker num_warnings = 0 1555*da0073e9SAndroid Build Coastguard Worker MAX_WARNINGS = 5 1556*da0073e9SAndroid Build Coastguard Worker for node in self.nodes: 1557*da0073e9SAndroid Build Coastguard Worker if 
node.op == 'call_function': 1558*da0073e9SAndroid Build Coastguard Worker if not callable(node.target): 1559*da0073e9SAndroid Build Coastguard Worker raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' 1560*da0073e9SAndroid Build Coastguard Worker 'a Callable is expected') 1561*da0073e9SAndroid Build Coastguard Worker else: 1562*da0073e9SAndroid Build Coastguard Worker if not isinstance(node.target, str): 1563*da0073e9SAndroid Build Coastguard Worker raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' 1564*da0073e9SAndroid Build Coastguard Worker 'a str is expected') 1565*da0073e9SAndroid Build Coastguard Worker if node.op in ['get_attr', 'call_module']: 1566*da0073e9SAndroid Build Coastguard Worker target_atoms = node.target.split('.') 1567*da0073e9SAndroid Build Coastguard Worker m_itr = self.owning_module 1568*da0073e9SAndroid Build Coastguard Worker for i, atom in enumerate(target_atoms): 1569*da0073e9SAndroid Build Coastguard Worker new_m_itr = getattr(m_itr, atom, None) 1570*da0073e9SAndroid Build Coastguard Worker seen_qualname = '.'.join(target_atoms[:i]) 1571*da0073e9SAndroid Build Coastguard Worker if new_m_itr is None: 1572*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' 1573*da0073e9SAndroid Build Coastguard Worker f'{atom} of {seen_qualname}') 1574*da0073e9SAndroid Build Coastguard Worker if (node.op == "call_module" 1575*da0073e9SAndroid Build Coastguard Worker and not isinstance(new_m_itr, torch.nn.Module)): 1576*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' 1577*da0073e9SAndroid Build Coastguard Worker 'not reference an nn.Module') 1578*da0073e9SAndroid Build Coastguard Worker elif (node.op == "get_attr" 1579*da0073e9SAndroid Build Coastguard Worker and not isinstance(new_m_itr, torch.nn.Module) 
1580*da0073e9SAndroid Build Coastguard Worker and not isinstance(new_m_itr, torch.nn.Parameter) 1581*da0073e9SAndroid Build Coastguard Worker and atom not in m_itr._buffers): 1582*da0073e9SAndroid Build Coastguard Worker if num_warnings < MAX_WARNINGS: 1583*da0073e9SAndroid Build Coastguard Worker # Don't emit this warning too frequently, 1584*da0073e9SAndroid Build Coastguard Worker # for very large graphs this can become very expensive 1585*da0073e9SAndroid Build Coastguard Worker # from a performance perspective. 1586*da0073e9SAndroid Build Coastguard Worker warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' 1587*da0073e9SAndroid Build Coastguard Worker 'not reference an nn.Module, nn.Parameter, or buffer, which is ' 1588*da0073e9SAndroid Build Coastguard Worker 'what \'get_attr\' Nodes typically target') 1589*da0073e9SAndroid Build Coastguard Worker num_warnings += 1 1590*da0073e9SAndroid Build Coastguard Worker else: 1591*da0073e9SAndroid Build Coastguard Worker m_itr = new_m_itr 1592*da0073e9SAndroid Build Coastguard Worker if num_warnings > MAX_WARNINGS: 1593*da0073e9SAndroid Build Coastguard Worker warnings.warn( 1594*da0073e9SAndroid Build Coastguard Worker f'Additional {num_warnings - MAX_WARNINGS} warnings ' 1595*da0073e9SAndroid Build Coastguard Worker 'suppressed about get_attr references' 1596*da0073e9SAndroid Build Coastguard Worker ) 1597*da0073e9SAndroid Build Coastguard Worker 1598*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 1599*da0073e9SAndroid Build Coastguard Worker def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): 1600*da0073e9SAndroid Build Coastguard Worker """ 1601*da0073e9SAndroid Build Coastguard Worker Remove all dead code from the graph, based on each node's number of 1602*da0073e9SAndroid Build Coastguard Worker users, and whether the nodes have any side effects. 
The graph must be 1603*da0073e9SAndroid Build Coastguard Worker topologically sorted before calling. 1604*da0073e9SAndroid Build Coastguard Worker 1605*da0073e9SAndroid Build Coastguard Worker Args: 1606*da0073e9SAndroid Build Coastguard Worker is_impure_node (Optional[Callable[[Node], bool]]): A function that returns 1607*da0073e9SAndroid Build Coastguard Worker whether a node is impure. If this is None, then the default behavior is to 1608*da0073e9SAndroid Build Coastguard Worker use Node.is_impure. 1609*da0073e9SAndroid Build Coastguard Worker 1610*da0073e9SAndroid Build Coastguard Worker Returns: 1611*da0073e9SAndroid Build Coastguard Worker bool: Whether the graph was changed as a result of the pass. 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker Example: 1614*da0073e9SAndroid Build Coastguard Worker 1615*da0073e9SAndroid Build Coastguard Worker Before dead code is eliminated, `a` from `a = x + 1` below has no users 1616*da0073e9SAndroid Build Coastguard Worker and thus can be eliminated from the graph without having an effect. 1617*da0073e9SAndroid Build Coastguard Worker 1618*da0073e9SAndroid Build Coastguard Worker .. code-block:: python 1619*da0073e9SAndroid Build Coastguard Worker 1620*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1621*da0073e9SAndroid Build Coastguard Worker a = x + 1 1622*da0073e9SAndroid Build Coastguard Worker return x + self.attr_1 1623*da0073e9SAndroid Build Coastguard Worker 1624*da0073e9SAndroid Build Coastguard Worker After dead code is eliminated, `a = x + 1` has been removed, and the rest 1625*da0073e9SAndroid Build Coastguard Worker of `forward` remains. 1626*da0073e9SAndroid Build Coastguard Worker 1627*da0073e9SAndroid Build Coastguard Worker .. 
code-block:: python 1628*da0073e9SAndroid Build Coastguard Worker 1629*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1630*da0073e9SAndroid Build Coastguard Worker return x + self.attr_1 1631*da0073e9SAndroid Build Coastguard Worker 1632*da0073e9SAndroid Build Coastguard Worker .. warning:: 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Worker Dead code elimination has some heuristics to avoid removing 1635*da0073e9SAndroid Build Coastguard Worker side-effectful nodes (see Node.is_impure) but in general coverage 1636*da0073e9SAndroid Build Coastguard Worker is very bad, so you should assume that this method is not sound 1637*da0073e9SAndroid Build Coastguard Worker to call unless you know that your FX graph consists entirely 1638*da0073e9SAndroid Build Coastguard Worker of functional operations or you supply your own custom 1639*da0073e9SAndroid Build Coastguard Worker function for detecting side-effectful nodes. 1640*da0073e9SAndroid Build Coastguard Worker """ 1641*da0073e9SAndroid Build Coastguard Worker # Lint the graph first to make sure its topologically sorted, otherwise 1642*da0073e9SAndroid Build Coastguard Worker # DCE below will not behave as expected. 1643*da0073e9SAndroid Build Coastguard Worker self.lint() 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker def has_side_effect(node): 1646*da0073e9SAndroid Build Coastguard Worker if is_impure_node is not None: 1647*da0073e9SAndroid Build Coastguard Worker return is_impure_node(node) 1648*da0073e9SAndroid Build Coastguard Worker return node.is_impure() 1649*da0073e9SAndroid Build Coastguard Worker 1650*da0073e9SAndroid Build Coastguard Worker # Reverse iterate so that when we remove a node, any nodes used as an 1651*da0073e9SAndroid Build Coastguard Worker # input to that node have an updated user count that no longer reflects 1652*da0073e9SAndroid Build Coastguard Worker # the removed node. 
1653*da0073e9SAndroid Build Coastguard Worker changed = False 1654*da0073e9SAndroid Build Coastguard Worker for node in reversed(self.nodes): 1655*da0073e9SAndroid Build Coastguard Worker if not has_side_effect(node) and len(node.users) == 0: 1656*da0073e9SAndroid Build Coastguard Worker self.erase_node(node) 1657*da0073e9SAndroid Build Coastguard Worker changed = True 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker return changed 1660*da0073e9SAndroid Build Coastguard Worker 1661*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=False) 1662*da0073e9SAndroid Build Coastguard Worker def set_codegen(self, codegen: CodeGen): 1663*da0073e9SAndroid Build Coastguard Worker self._codegen = codegen 1664*da0073e9SAndroid Build Coastguard Worker 1665*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=False) 1666*da0073e9SAndroid Build Coastguard Worker def on_generate_code( 1667*da0073e9SAndroid Build Coastguard Worker self, 1668*da0073e9SAndroid Build Coastguard Worker make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] 1669*da0073e9SAndroid Build Coastguard Worker ): 1670*da0073e9SAndroid Build Coastguard Worker """Register a transformer function when python code is generated 1671*da0073e9SAndroid Build Coastguard Worker 1672*da0073e9SAndroid Build Coastguard Worker Args: 1673*da0073e9SAndroid Build Coastguard Worker make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): 1674*da0073e9SAndroid Build Coastguard Worker a function that returns a code transformer to be registered. 1675*da0073e9SAndroid Build Coastguard Worker This function is called by `on_generate_code` to obtain the 1676*da0073e9SAndroid Build Coastguard Worker code transformer. 
1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker This function is also given as its input the currently 1679*da0073e9SAndroid Build Coastguard Worker registered code transformer (or None if nothing is registered), 1680*da0073e9SAndroid Build Coastguard Worker in case it is not desirable to overwrite it. This is useful to 1681*da0073e9SAndroid Build Coastguard Worker chain code transformers together. 1682*da0073e9SAndroid Build Coastguard Worker 1683*da0073e9SAndroid Build Coastguard Worker Returns: 1684*da0073e9SAndroid Build Coastguard Worker a context manager that when used in a `with` statement, to automatically 1685*da0073e9SAndroid Build Coastguard Worker restore the previously registered code transformer. 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker Example: 1688*da0073e9SAndroid Build Coastguard Worker 1689*da0073e9SAndroid Build Coastguard Worker .. code-block:: python 1690*da0073e9SAndroid Build Coastguard Worker 1691*da0073e9SAndroid Build Coastguard Worker 1692*da0073e9SAndroid Build Coastguard Worker gm: fx.GraphModule = ... 1693*da0073e9SAndroid Build Coastguard Worker 1694*da0073e9SAndroid Build Coastguard Worker # This is a code transformer we want to register. This code 1695*da0073e9SAndroid Build Coastguard Worker # transformer prepends a pdb import and trace statement at the very 1696*da0073e9SAndroid Build Coastguard Worker # beginning of the generated torch.fx code to allow for manual 1697*da0073e9SAndroid Build Coastguard Worker # debugging with the PDB library. 
1698*da0073e9SAndroid Build Coastguard Worker def insert_pdb(body): 1699*da0073e9SAndroid Build Coastguard Worker return ["import pdb; pdb.set_trace()\\n", *body] 1700*da0073e9SAndroid Build Coastguard Worker 1701*da0073e9SAndroid Build Coastguard Worker # Registers `insert_pdb`, and overwrites the current registered 1702*da0073e9SAndroid Build Coastguard Worker # code transformer (given by `_` to the lambda): 1703*da0073e9SAndroid Build Coastguard Worker gm.graph.on_generate_code( 1704*da0073e9SAndroid Build Coastguard Worker lambda _: insert_pdb 1705*da0073e9SAndroid Build Coastguard Worker ) 1706*da0073e9SAndroid Build Coastguard Worker 1707*da0073e9SAndroid Build Coastguard Worker # Or alternatively, registers a code transformer which first 1708*da0073e9SAndroid Build Coastguard Worker # runs `body` through existing registered transformer, then 1709*da0073e9SAndroid Build Coastguard Worker # through `insert_pdb`: 1710*da0073e9SAndroid Build Coastguard Worker gm.graph.on_generate_code( 1711*da0073e9SAndroid Build Coastguard Worker lambda current_trans: ( 1712*da0073e9SAndroid Build Coastguard Worker lambda body: insert_pdb( 1713*da0073e9SAndroid Build Coastguard Worker current_trans(body) if current_trans 1714*da0073e9SAndroid Build Coastguard Worker else body 1715*da0073e9SAndroid Build Coastguard Worker ) 1716*da0073e9SAndroid Build Coastguard Worker ) 1717*da0073e9SAndroid Build Coastguard Worker ) 1718*da0073e9SAndroid Build Coastguard Worker 1719*da0073e9SAndroid Build Coastguard Worker gm.recompile() 1720*da0073e9SAndroid Build Coastguard Worker gm(*inputs) # drops into pdb 1721*da0073e9SAndroid Build Coastguard Worker 1722*da0073e9SAndroid Build Coastguard Worker 1723*da0073e9SAndroid Build Coastguard Worker This function can also be used as a context manager, with the benefit to 1724*da0073e9SAndroid Build Coastguard Worker automatically restores the previously registered code transformer: 1725*da0073e9SAndroid Build Coastguard Worker 
1726*da0073e9SAndroid Build Coastguard Worker .. code-block:: python 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker # ... continue from previous example 1729*da0073e9SAndroid Build Coastguard Worker 1730*da0073e9SAndroid Build Coastguard Worker with gm.graph.on_generate_code(lambda _: insert_pdb): 1731*da0073e9SAndroid Build Coastguard Worker # do more stuff with `gm`... 1732*da0073e9SAndroid Build Coastguard Worker gm.recompile() 1733*da0073e9SAndroid Build Coastguard Worker gm(*inputs) # drops into pdb 1734*da0073e9SAndroid Build Coastguard Worker 1735*da0073e9SAndroid Build Coastguard Worker # now previous code transformer is restored (but `gm`'s code with pdb 1736*da0073e9SAndroid Build Coastguard Worker # remains - that means you can run `gm` with pdb here too, until you 1737*da0073e9SAndroid Build Coastguard Worker # run next `recompile()`). 1738*da0073e9SAndroid Build Coastguard Worker """ 1739*da0073e9SAndroid Build Coastguard Worker on_gen_code_old = self._codegen._body_transformer 1740*da0073e9SAndroid Build Coastguard Worker self._codegen._body_transformer = make_transformer(on_gen_code_old) 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 1743*da0073e9SAndroid Build Coastguard Worker def on_generate_code_context_manager(): 1744*da0073e9SAndroid Build Coastguard Worker try: 1745*da0073e9SAndroid Build Coastguard Worker yield 1746*da0073e9SAndroid Build Coastguard Worker finally: 1747*da0073e9SAndroid Build Coastguard Worker self._codegen._body_transformer = on_gen_code_old 1748*da0073e9SAndroid Build Coastguard Worker 1749*da0073e9SAndroid Build Coastguard Worker return on_generate_code_context_manager() 1750*da0073e9SAndroid Build Coastguard Worker 1751*da0073e9SAndroid Build Coastguard Worker 1752*da0073e9SAndroid Build Coastguard Workerreflectable_magic_methods = { 1753*da0073e9SAndroid Build Coastguard Worker 'add': '{} + {}', 
1754*da0073e9SAndroid Build Coastguard Worker 'sub': '{} - {}', 1755*da0073e9SAndroid Build Coastguard Worker 'mul': '{} * {}', 1756*da0073e9SAndroid Build Coastguard Worker 'floordiv': '{} // {}', 1757*da0073e9SAndroid Build Coastguard Worker 'truediv': '{} / {}', 1758*da0073e9SAndroid Build Coastguard Worker 'div': '{} / {}', 1759*da0073e9SAndroid Build Coastguard Worker 'mod': '{} % {}', 1760*da0073e9SAndroid Build Coastguard Worker 'pow': '{} ** {}', 1761*da0073e9SAndroid Build Coastguard Worker 'lshift': '{} << {}', 1762*da0073e9SAndroid Build Coastguard Worker 'rshift': '{} >> {}', 1763*da0073e9SAndroid Build Coastguard Worker 'and_': '{} & {}', 1764*da0073e9SAndroid Build Coastguard Worker 'or_': '{} | {}', 1765*da0073e9SAndroid Build Coastguard Worker 'xor': '{} ^ {}', 1766*da0073e9SAndroid Build Coastguard Worker 'getitem': '{}[{}]', 1767*da0073e9SAndroid Build Coastguard Worker 'matmul': '{} @ {}', 1768*da0073e9SAndroid Build Coastguard Worker} 1769*da0073e9SAndroid Build Coastguard Worker 1770*da0073e9SAndroid Build Coastguard Workermagic_methods = dict({ 1771*da0073e9SAndroid Build Coastguard Worker 'eq': '{} == {}', 1772*da0073e9SAndroid Build Coastguard Worker 'ne': '{} != {}', 1773*da0073e9SAndroid Build Coastguard Worker 'lt': '{} < {}', 1774*da0073e9SAndroid Build Coastguard Worker 'gt': '{} > {}', 1775*da0073e9SAndroid Build Coastguard Worker 'le': '{} <= {}', 1776*da0073e9SAndroid Build Coastguard Worker 'ge': '{} >= {}', 1777*da0073e9SAndroid Build Coastguard Worker 'pos': '+{}', 1778*da0073e9SAndroid Build Coastguard Worker 'neg': '-{}', 1779*da0073e9SAndroid Build Coastguard Worker 'invert': '~{}'}, **reflectable_magic_methods) 1780*da0073e9SAndroid Build Coastguard Worker 1781*da0073e9SAndroid Build Coastguard Workerinplace_methods = { 1782*da0073e9SAndroid Build Coastguard Worker 'iadd': '{} += {}', 1783*da0073e9SAndroid Build Coastguard Worker 'iand': '{} &= {}', 1784*da0073e9SAndroid Build Coastguard Worker 'ifloordiv': '{} //= {}', 
1785*da0073e9SAndroid Build Coastguard Worker 'ilshift': '{} <<= {}', 1786*da0073e9SAndroid Build Coastguard Worker 'imod': '{} %= {}', 1787*da0073e9SAndroid Build Coastguard Worker 'imul': '{} *= {}', 1788*da0073e9SAndroid Build Coastguard Worker 'imatmul': '{} @= {}', 1789*da0073e9SAndroid Build Coastguard Worker 'ior': '{} |= {}', 1790*da0073e9SAndroid Build Coastguard Worker 'ipow': '{} **= {}', 1791*da0073e9SAndroid Build Coastguard Worker 'irshift': '{} >>= {}', 1792*da0073e9SAndroid Build Coastguard Worker 'isub': '{} -= {}', 1793*da0073e9SAndroid Build Coastguard Worker 'itruediv': '{} /= {}', 1794*da0073e9SAndroid Build Coastguard Worker 'ixor': '{} ^= {}', 1795*da0073e9SAndroid Build Coastguard Worker 'setitem': '{}[{}] = {}', 1796*da0073e9SAndroid Build Coastguard Worker} 1797