1# Nodes represent a definition of a value in our graph of operators. 2from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set 3from ._compatibility import compatibility 4from .immutable_collections import immutable_dict, immutable_list 5import torch 6import builtins 7import types 8import inspect 9import warnings 10from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair 11from .._ops import ops as _ops 12from torch._C import _NodeBase 13 14if TYPE_CHECKING: 15 from .graph import Graph 16 17__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"] 18 19BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, 20 torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, 21 torch.SymInt, torch.SymBool, torch.SymFloat] 22base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] 23 24Target = Union[Callable[..., Any], str] 25 26Argument = Optional[Union[ 27 Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types 28 List[Any], # actually Argument 29 Dict[str, Any], # actually Argument 30 slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing 31 range, 32 'Node', 33 BaseArgumentTypes 34]] 35 36_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']) 37 38_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { 39 torch._C._set_grad_enabled, 40 torch.amp._enter_autocast, 41 torch.amp._exit_autocast, 42} 43 44# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, 45# or add logic to correctly mark all inplace ops as side effectful. 
# Callables whose `call_function` nodes `Node.is_impure` reports as impure,
# regardless of their schema: runtime assertions, symbolic range constraints,
# profiler record-function markers, in-place gradient accumulation, and the
# grad-mode/autocast toggles collected in the pre-dispatch set.
_side_effectful_functions: Set[Callable] = {
    torch._assert,
    torch._assert_async,
    _ops.aten._assert_async.msg,
    _ops.aten._assert_scalar.default,
    _ops.aten.sym_constrain_range.default,
    _ops.aten.sym_constrain_range_for_size.default,
    _ops.profiler._record_function_enter,
    _ops.profiler._record_function_enter_new,
    _ops.profiler._record_function_exit,
    _ops.inductor.accumulate_grad_.default,
}
_side_effectful_functions |= _side_effectful_need_to_be_preserved_pre_dispatch
# `resize_storage_bytes_` may not be registered under torch.ops.inductor, so it
# is only added when present.
if hasattr(_ops.inductor, "resize_storage_bytes_"):
    _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default)


@compatibility(is_backward_compatible=False)
def has_side_effect(fn: Callable) -> Callable:
    """
    Register ``fn`` as side-effectful and hand it back unchanged.

    Because ``fn`` is returned as-is, this can be used as a decorator.
    ``call_function`` nodes targeting a registered callable are reported as
    impure by ``Node.is_impure``.
    """
    _side_effectful_functions.add(fn)
    return fn


# this is fixed on master, WAR for 1.5
def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
    """
    Name the module that ``orig_method`` belongs to.

    ``orig_method.__module__`` is used whenever it is not ``None``. Otherwise
    ``torch`` and then ``torch.nn.functional`` are searched for an attribute
    with the same ``__name__`` that is the very same object.

    Raises:
        RuntimeError: if neither strategy identifies a module.
    """
    method_name = orig_method.__name__
    owner = orig_method.__module__
    if owner is None:
        candidates = (torch, torch.nn.functional)
        owner = next(
            (mod.__name__ for mod in candidates if getattr(mod, method_name, None) is orig_method),
            None,
        )
        if owner is None:
            raise RuntimeError(f'cannot find module for {orig_method}')
    return owner

# Borrowed from CPython typing module
# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
def _type_repr(obj: object) -> str:
    """
    Produce a compact repr of ``obj`` (internal helper).

    * classes: bare ``__qualname__`` for builtins, ``module.qualname`` otherwise;
    * the ``Ellipsis`` singleton: ``'...'``;
    * plain Python functions: their ``__name__``;
    * anything else: ``repr(obj)``.
    """
    if obj is ...:
        return '...'
    if isinstance(obj, type):
        module = obj.__module__
        return obj.__qualname__ if module == 'builtins' else f'{module}.{obj.__qualname__}'
    if isinstance(obj, types.FunctionType):
        return obj.__name__
    return repr(obj)

def _get_qualified_name(func: Callable[..., Any]) -> str:
    """
    Build a dotted name such as ``torch.add`` for ``func``.

    Builtins yield their bare name (e.g. ``getattr``), and ``torch.Tensor``
    method/wrapper descriptors yield ``torch.Tensor.<name>``. Everything else
    yields ``<module>.<name>``, where ``torch._ops`` is rewritten to
    ``torch.ops``. For a lambda, the name is the source text preceding the
    first ``=``, so ``f = lambda x: x`` yields ``f``.

    Raises:
        RuntimeError: if the source of a lambda cannot be retrieved.
    """
    short_name = func.__name__
    # things like getattr just appear in builtins
    if getattr(builtins, short_name, None) is func:
        return short_name
    # torch.Tensor.{fn}
    is_descriptor = isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType))
    if is_descriptor and getattr(torch.Tensor, short_name, None) is func:
        return f"torch.Tensor.{short_name}"
    if short_name == "<lambda>":
        # For lambdas, try to get their defining name in the module
        try:
            source = inspect.getsource(func)
        except Exception as e:
            raise RuntimeError("Unable to represent lambda") from e
        short_name = source.split("=")[0].strip()
    # WAR for bug in how torch.ops assigns module
    module = _find_module_of_method(func).replace('torch._ops', 'torch.ops')
    # Fixup segment_reduce mismatch
    if (module, short_name) == ("torch", "segment_reduce"):
        short_name = "_segment_reduce"
    return f'{module}.{short_name}'

def _format_arg(arg: object, max_list_len: float = float('inf')) -> str:
    """
    Render ``arg`` for graph printouts.

    Objects with a ``_custom_fx_repr_fn`` render themselves. A top-level list
    or tuple shows its elements at indices below ``max_list_len``, followed by
    ``...[total_len=N]`` when it is truncated; nested containers are never
    truncated. Dicts render as ``{key: value, ...}``, Nodes as ``%name``, and
    anything else via ``str``.
    """
    if hasattr(arg, '_custom_fx_repr_fn'):
        return arg._custom_fx_repr_fn()
    if isinstance(arg, (list, tuple)):
        shown = ', '.join(_format_arg(elem) for pos, elem in enumerate(arg) if pos < max_list_len)
        suffix = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
        if isinstance(arg, list):
            return f'[{shown}{suffix}]'
        # A one-element tuple needs its trailing comma to read as a tuple.
        trailing_comma = ',' if len(arg) == 1 else ''
        return f'({shown}{trailing_comma}{suffix})'
    if isinstance(arg, dict):
        body = ', '.join(f'{key}: {_format_arg(val)}' for key, val in arg.items())
        return f'{{{body}}}'
    return '%' + str(arg) if isinstance(arg, Node) else str(arg)
141@compatibility(is_backward_compatible=True) 142class Node(_NodeBase): 143 """ 144 ``Node`` is the data structure that represents individual operations within 145 a ``Graph``. For the most part, Nodes represent callsites to various entities, 146 such as operators, methods, and Modules (some exceptions include nodes that 147 specify function inputs and outputs). Each ``Node`` has a function specified 148 by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: 149 150 - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. 151 ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument 152 denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to 153 the function parameters (e.g. ``x``) in the graph printout. 154 - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the 155 fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. 156 ``args`` and ``kwargs`` are don't-care 157 - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign 158 to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, 159 following the Python calling convention 160 - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is 161 as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. 162 ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*. 163 - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method 164 to apply to the ``self`` argument. 
``args`` and ``kwargs`` represent the arguments to invoke the module on, 165 *including the self argument* 166 - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement 167 in the Graph printout. 168 """ 169 _args: Tuple['Argument', ...] 170 _kwargs: Dict[str, 'Argument'] 171 172 @compatibility(is_backward_compatible=True) 173 def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', 174 args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], 175 return_type : Optional[Any] = None) -> None: 176 """ 177 Instantiate an instance of ``Node``. Note: most often, you want to use the 178 Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather 179 than instantiating a ``Node`` directly. 180 181 Args: 182 graph (Graph): The ``Graph`` to which this ``Node`` should belong. 183 184 name (str): The name to which the output of this ``Node`` should be assigned 185 186 op (str): The opcode for this ``Node``. Can be one of 'placeholder', 187 'call_method', 'call_module', 'call_function', 'get_attr', 188 'output' 189 190 target ('Target'): The target this op should call. See the broader 191 ``Node`` docstring for more details. 192 193 args (Tuple['Argument']): The args to be passed to ``target`` 194 195 kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` 196 197 return_type (Optional[Any]): The python type expression representing the 198 type of the output of this node. This field can be used for 199 annotation of values in the generated code or for other types 200 of analyses. 
201 """ 202 super().__init__() 203 self.graph = graph 204 self.name = name # unique name of value being created 205 assert op in _legal_ops 206 self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr 207 if op == 'call_function': 208 if not callable(target): 209 raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 210 'but a Callable is expected') 211 else: 212 if not isinstance(target, str): 213 raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 214 'but a str is expected') 215 self.target = target # for method/module/function, the name of the method/module/function/attr 216 # being invoked, e.g add, layer1, or torch.add 217 218 # All `Node`-valued inputs. Key is the Node, value is don't-care. 219 # The public API for this is `all_input_nodes`, this private attribute 220 # should not be accessed directly. 221 self._input_nodes : Dict[Node, None] = {} 222 self.__update_args_kwargs(args, kwargs) 223 224 # All of the nodes that use the value produced by this Node 225 # Note one user may correspond to several uses, e.g. the node fo ``x + x`` 226 # would appear once here, but represents two uses. 227 # 228 # Is a dict to act as an "ordered set". Keys are significant, value dont-care 229 self.users : Dict[Node, None] = {} 230 # Type expression representing the output value of this node. 231 # This should contain the same class of Type objects that would appear 232 # as type annotations for function inputs/outputs. 233 # 234 # For placeholder nodes, this value will be used to type-annotate the 235 # generated function parameters. 236 # For the return node, this value will be used to type-annotate the 237 # generated function return type. (Note this is a special case. ``return`` 238 # does not produce a value, it's more of a notation. Thus, this value 239 # describes the type of args[0] in the ``return`` node. 
240 self.type : Optional[Any] = return_type 241 self._sort_key: Any = () 242 243 # If set, use this fn to print this node 244 self._repr_fn : Optional[Callable[[Node], str]] = None 245 246 # Dictionary to store metadata passes need to do their 247 # transformations. This metadata is preserved across node copies 248 self.meta : Dict[str, Any] = {} 249 250 def __getstate__(self) -> Dict[str, Any]: 251 state = self.__dict__.copy() 252 state["_erased"] = self._erased 253 state["_prev"] = self._prev 254 state["_next"] = self._next 255 return state 256 257 def __setstate__(self, state: Dict[str, Any]) -> None: 258 _erased = state.pop("_erased") 259 _prev = state.pop("_prev") 260 _next = state.pop("_next") 261 self.__dict__.update(state) 262 self._erased = _erased 263 self._prev = _prev 264 self._next = _next 265 266 @property 267 def next(self) -> 'Node': 268 """ 269 Returns the next ``Node`` in the linked list of Nodes. 270 271 Returns: 272 273 The next ``Node`` in the linked list of Nodes. 274 """ 275 return self._next 276 277 @property 278 def prev(self) -> 'Node': 279 """ 280 Returns the previous ``Node`` in the linked list of Nodes. 281 282 Returns: 283 284 The previous ``Node`` in the linked list of Nodes. 285 """ 286 return self._prev 287 288 @compatibility(is_backward_compatible=True) 289 def prepend(self, x: 'Node') -> None: 290 """ 291 Insert x before this node in the list of nodes in the graph. Example:: 292 293 Before: p -> self 294 bx -> x -> ax 295 After: p -> x -> self 296 bx -> ax 297 298 Args: 299 x (Node): The node to put before this node. Must be a member of the same graph. 300 """ 301 assert self.graph == x.graph, "Attempting to move a Node into a different Graph" 302 if self == x: 303 warnings.warn("Trying to prepend a node to itself. 
This behavior has no effect on the graph.") 304 return 305 x._remove_from_list() 306 p = self._prev 307 p._next, x._prev = x, p 308 x._next, self._prev = self, x 309 310 # compute x._sort_key 311 psk = x._prev._sort_key 312 nsk = x._next._sort_key 313 if len(psk) > len(nsk): 314 idx: int 315 *prefix, idx = psk[:len(nsk) + 1] 316 x._sort_key = (*prefix, idx + 1) 317 elif len(psk) < len(nsk): 318 *prefix, idx = nsk[:len(psk) + 1] 319 x._sort_key = (*prefix, idx - 1) 320 else: # same length, increase length by 1 321 x._sort_key = (*psk, 0) 322 323 def __gt__(self, other: 'Node') -> bool: 324 return self._sort_key > other._sort_key 325 326 def __lt__(self, other: 'Node') -> bool: 327 return self._sort_key < other._sort_key 328 329 def __ge__(self, other: 'Node') -> bool: 330 return self > other or self == other 331 332 def __le__(self, other: 'Node') -> bool: 333 return self < other or self == other 334 335 @compatibility(is_backward_compatible=True) 336 def append(self, x: 'Node') -> None: 337 """ 338 Insert ``x`` after this node in the list of nodes in the graph. 339 Equivalent to ``self.next.prepend(x)`` 340 341 Args: 342 x (Node): The node to put after this node. Must be a member of the same graph. 343 """ 344 self._next.prepend(x) 345 346 def _remove_from_list(self) -> None: 347 p, n = self._prev, self._next 348 p._next, n._prev = n, p 349 350 @property 351 def args(self) -> Tuple[Argument, ...]: 352 """ 353 The tuple of arguments to this ``Node``. The interpretation of arguments 354 depends on the node's opcode. See the :class:`Node` docstring for more 355 information. 356 357 Assignment to this property is allowed. All accounting of uses and users 358 is updated automatically on assignment. 359 """ 360 return self._args 361 362 @args.setter 363 def args(self, a : Tuple[Argument, ...]) -> None: 364 """ 365 Set the tuple of arguments to this Node. The interpretation of arguments 366 depends on the node's opcode. 
See the ``fx.Graph`` docstring for more 367 information. 368 """ 369 # DO NOT CALL `__update_args_kwargs` directly. The correct way to 370 # set `args` is via direct assignment, i.e. `node.args = new_args` 371 self.__update_args_kwargs(a, self._kwargs) 372 373 @property 374 def kwargs(self) -> Dict[str, Argument]: 375 """ 376 The dict of keyword arguments to this ``Node``. The interpretation of arguments 377 depends on the node's opcode. See the :class:`Node` docstring for more 378 information. 379 380 Assignment to this property is allowed. All accounting of uses and users 381 is updated automatically on assignment. 382 """ 383 return self._kwargs 384 385 @kwargs.setter 386 def kwargs(self, k : Dict[str, Argument]) -> None: 387 """ 388 Set the dict of kwargs to this Node. The interpretation of arguments 389 depends on the node's opcode. See the ``fx.Graph`` docstring for more 390 information. 391 """ 392 # DO NOT CALL `__update_args_kwargs` directly. The correct way to 393 # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` 394 self.__update_args_kwargs(self._args, k) 395 396 @property 397 def all_input_nodes(self) -> List['Node']: 398 """ 399 Return all Nodes that are inputs to this Node. This is equivalent to 400 iterating over ``args`` and ``kwargs`` and only collecting the values that 401 are Nodes. 402 403 Returns: 404 405 List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this 406 ``Node``, in that order. 407 """ 408 return list(self._input_nodes.keys()) 409 410 @compatibility(is_backward_compatible=True) 411 def update_arg(self, idx : int, arg : Argument) -> None: 412 """ 413 Update an existing positional argument to contain the new value 414 ``arg``. After calling, ``self.args[idx] == arg``. 
415 416 Args: 417 418 idx (int): The index into ``self.args`` of the element to update 419 arg (Argument): The new argument value to write into ``args`` 420 """ 421 args = list(self.args) 422 args[idx] = arg 423 self.args = tuple(args) 424 425 @compatibility(is_backward_compatible=True) 426 def insert_arg(self, idx : int, arg : Argument) -> None: 427 """ 428 Insert an positional argument to the argument list with given index. 429 430 Args: 431 432 idx (int): The index of the element in ``self.args`` to be inserted before. 433 arg (Argument): The new argument value to insert into ``args`` 434 """ 435 assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)" 436 args_left = self.args[:idx] 437 args_right = self.args[idx:] 438 439 self._args = args_left + (arg,) + args_right 440 441 _new_input_nodes: Dict[Node, None] = {} 442 map_arg(arg, _new_input_nodes.setdefault) 443 444 for new_use in _new_input_nodes.keys(): 445 if new_use not in self._input_nodes: 446 self._input_nodes.setdefault(new_use) 447 new_use.users.setdefault(self) 448 449 @compatibility(is_backward_compatible=True) 450 def update_kwarg(self, key : str, arg : Argument) -> None: 451 """ 452 Update an existing keyword argument to contain the new value 453 ``arg``. After calling, ``self.kwargs[key] == arg``. 454 455 Args: 456 457 key (str): The key in ``self.kwargs`` of the element to update 458 arg (Argument): The new argument value to write into ``kwargs`` 459 """ 460 self.kwargs = {**self.kwargs, key: arg} 461 462 @property 463 def stack_trace(self) -> Optional[str]: 464 """ 465 Return the Python stack trace that was recorded during tracing, if any. 466 When traced with fx.Tracer, this property is usually populated by 467 `Tracer.create_proxy`. To record stack traces during tracing for debug purposes, 468 set `record_stack_traces = True` on the `Tracer` instance. 469 When traced with dynamo, this property will be populated by default by 470 `OutputGraph.create_proxy`. 
471 472 stack_trace would have the innermost frame at the end of the string. 473 """ 474 return self.meta.get("stack_trace", None) 475 476 @stack_trace.setter 477 def stack_trace(self, trace : Optional[str]) -> None: 478 self.meta["stack_trace"] = trace 479 480 def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None: 481 """ 482 This API is internal. Do *not* call it directly. 483 """ 484 def update_users_and_input_nodes(n: Any) -> Any: 485 if isinstance(n, Node): 486 self._input_nodes.setdefault(n) 487 n.users.setdefault(self) 488 return n 489 490 # Clear prior users and input_nodes 491 for old_use in self._input_nodes.keys(): 492 old_use.users.pop(self) 493 self._input_nodes = {} 494 495 # We do three things in a single pass of the args 496 # - Normalize list->immutable_list, dict->immutable_dict, etc 497 # - Populate self._input_nodes 498 # - Populate arg.users[self] for each arg 499 self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] 500 self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] 501 502 def __repr__(self) -> str: 503 if self._repr_fn: 504 return self._repr_fn(self) 505 return self.name 506 507 def _pretty_print_target(self, target: object) -> str: 508 """ 509 Make target printouts more user-friendly. 510 1) builtins will be printed as `builtins.xyz` 511 2) operators will be printed as `operator.xyz` 512 3) other callables will be printed with qualified name, e.g. torch.add 513 """ 514 if isinstance(target, str): 515 return target 516 if hasattr(target, '__module__'): 517 name = getattr(target, '__name__', None) 518 if name is None: 519 # Just to be defensive, if we don't have `__name__`, get the 520 # qualname. Not sure if this happens for any members of `operator` 521 # or `builtins`. This fallback path is not as good, since e.g. 522 # things in `operator` have `_operator` as their __module__. 
523 # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` 524 return _get_qualified_name(target) # type: ignore[arg-type] 525 if target.__module__ == 'builtins': 526 return f'builtins.{name}' 527 elif target.__module__ == '_operator': 528 return f'operator.{name}' 529 return _get_qualified_name(target) # type: ignore[arg-type] 530 531 @compatibility(is_backward_compatible=True) 532 def format_node(self, 533 placeholder_names: Optional[List[str]] = None, 534 maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: 535 """ 536 Return a descriptive string representation of ``self``. 537 538 This method can be used with no arguments as a debugging 539 utility. 540 541 This function is also used internally in the ``__str__`` method 542 of ``Graph``. Together, the strings in ``placeholder_names`` 543 and ``maybe_return_typename`` make up the signature of the 544 autogenerated ``forward`` function in this Graph's surrounding 545 GraphModule. ``placeholder_names`` and ``maybe_return_typename`` 546 should not be used otherwise. 547 548 Args: 549 placeholder_names: A list that will store formatted strings 550 representing the placeholders in the generated 551 ``forward`` function. Internal use only. 552 maybe_return_typename: A single-element list that will store 553 a formatted string representing the output of the 554 generated ``forward`` function. Internal use only. 555 556 Returns: 557 str: If 1) we're using ``format_node`` as an internal helper 558 in the ``__str__`` method of ``Graph``, and 2) ``self`` 559 is a placeholder Node, return ``None``. Otherwise, 560 return a descriptive string representation of the 561 current Node. 
562 """ 563 if self.op == 'placeholder': 564 assert isinstance(self.target, str) 565 arg_str = self.target 566 arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' 567 if placeholder_names: 568 placeholder_names.append(arg_str) 569 return None 570 maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' 571 default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' 572 return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' 573 elif self.op == 'get_attr': 574 maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' 575 return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ 576 f'{self.op}[target={self._pretty_print_target(self.target)}]' 577 elif self.op == 'output': 578 if self.type and maybe_return_typename: 579 maybe_return_typename[0] = f' -> {_type_repr(self.type)}' 580 return f'return {self.args[0]}' 581 else: 582 maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' 583 return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ 584 f'{self.op}[target={self._pretty_print_target(self.target)}](' \ 585 f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' 586 587 @compatibility(is_backward_compatible=True) 588 def replace_all_uses_with(self, 589 replace_with: 'Node', 590 delete_user_cb: Callable[['Node'], bool] = lambda user: True, 591 *, 592 propagate_meta: bool = False 593 ) -> List['Node']: 594 """ 595 Replace all uses of ``self`` in the Graph with the Node ``replace_with``. 596 597 Args: 598 599 replace_with (Node): The node to replace all uses of ``self`` with. 600 delete_user_cb (Callable): Callback that is called to determine 601 whether a given user of the self node should be removed. 602 propagate_meta (bool): Whether or not to copy all properties 603 on the .meta field of the original node onto the replacement node. 
604 For safety, this is only valid to do if the replacement node 605 doesn't already have an existing .meta field. 606 607 Returns: 608 609 The list of Nodes on which this change was made. 610 """ 611 if propagate_meta: 612 assert len(replace_with.meta) == 0, \ 613 'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \ 614 'but replace_with already has .meta keys' 615 for k, v in self.meta.items(): 616 replace_with.meta[k] = v 617 to_process = list(self.users) 618 skipped = [] 619 m = self.graph.owning_module 620 for use_node in to_process: 621 if not delete_user_cb(use_node): 622 skipped.append(use_node) 623 continue 624 625 def maybe_replace_node(n : Node) -> Node: 626 if n == self: 627 return replace_with 628 else: 629 return n 630 631 if getattr(m, "_replace_hook", None): 632 m._replace_hook(old=self, new=replace_with.name, user=use_node) 633 634 new_args = map_arg(use_node.args, maybe_replace_node) 635 new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) 636 assert isinstance(new_args, tuple) 637 assert isinstance(new_kwargs, dict) 638 use_node.__update_args_kwargs(new_args, new_kwargs) 639 640 assert len(self.users) - len(skipped) == 0 641 return [n for n in to_process if n not in skipped] 642 643 @compatibility(is_backward_compatible=False) 644 def is_impure(self) -> bool: 645 """ 646 Returns whether this op is impure, i.e. if its op is a placeholder or 647 output, or if a call_function or call_module which is impure. 648 649 Returns: 650 651 bool: If the op is impure or not. 652 """ 653 if self.op in {"placeholder", "output"}: 654 return True 655 656 # Check if an impure function based on schema. 657 if self.op == "call_function": 658 schema = getattr(self.target, "_schema", None) 659 schema_mutable = schema is not None and schema.is_mutable 660 return schema_mutable or self.target in _side_effectful_functions 661 662 # Check if an impure module. 
663 if self.op == "call_module": 664 assert ( 665 self.graph.owning_module is not None 666 ), "self.graph.owning_module not set for purity check" 667 target_mod = self.graph.owning_module.get_submodule(self.target) 668 assert ( 669 target_mod is not None 670 ), f"Did not find expected submodule target {self.target}" 671 return getattr(target_mod, "_is_impure", False) 672 673 return False 674 675 @compatibility(is_backward_compatible=False) 676 def normalized_arguments( 677 self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, 678 kwarg_types : Optional[Dict[str, Any]] = None, 679 normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: 680 """ 681 Returns normalized arguments to Python targets. This means that 682 `args/kwargs` will be matched up to the module/functional's 683 signature and return exclusively kwargs in positional order 684 if `normalize_to_only_use_kwargs` is true. 685 Also populates default values. Does not support positional-only 686 parameters or varargs parameters. 687 688 Supports module calls. 689 690 May require `arg_types` and `kwarg_types` in order to disambiguate overloads. 691 692 Args: 693 root (torch.nn.Module): Module upon which to resolve module targets. 694 arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args 695 kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs 696 normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. 697 698 Returns: 699 700 Returns NamedTuple ArgsKwargsPair, or `None` if not successful. 
701 """ 702 if self.op == 'call_function': 703 assert callable(self.target) 704 return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] 705 elif self.op == 'call_module': 706 assert isinstance(self.target, str) 707 return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] 708 709 return None 710 711 @compatibility(is_backward_compatible=True) 712 def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None: 713 """ 714 Loop through input nodes of ``self``, and replace all instances of 715 ``old_input`` with ``new_input``. 716 717 Args: 718 719 old_input (Node): The old input node to be replaced. 720 new_input (Node): The new input node to replace ``old_input``. 721 """ 722 def maybe_replace_node(n : Node) -> Node: 723 return new_input if n == old_input else n 724 725 m = self.graph.owning_module 726 if getattr(m, "_replace_hook", None): 727 m._replace_hook(old=old_input, new=new_input.name, user=self) 728 729 new_args = map_arg(self.args, maybe_replace_node) 730 new_kwargs = map_arg(self.kwargs, maybe_replace_node) 731 assert isinstance(new_args, tuple) 732 assert isinstance(new_kwargs, dict) 733 self.__update_args_kwargs(new_args, new_kwargs) 734 735 def _rename(self, candidate: str) -> None: 736 if candidate == self.name: 737 return 738 name = self.graph._graph_namespace.create_name(candidate, None) 739 self.name = name 740 self.graph._graph_namespace._rename_object(self, name) 741 742 def __setattr__(self, name: str, value: Any) -> None: 743 if name == 'name' and hasattr(self, "name"): 744 m = self.graph.owning_module 745 if getattr(m, "_replace_hook", None): 746 assert isinstance(value, str) 747 for user in self.users: 748 m._replace_hook(old=self, new=value, user=user) 749 update = False 750 if ( 751 hasattr(self, name) and 752 hasattr(self.graph, "_find_nodes_lookup_table") and 753 self in self.graph._find_nodes_lookup_table 754 ): 755 update = True 756 
self.graph._find_nodes_lookup_table.remove(self) 757 object.__setattr__(self, name, value) 758 if update: 759 self.graph._find_nodes_lookup_table.insert(self) 760 761@compatibility(is_backward_compatible=True) 762def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: 763 """ 764 Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. 765 """ 766 assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" 767 return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) 768 769@compatibility(is_backward_compatible=True) 770def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: 771 """ 772 Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. 773 """ 774 if isinstance(a, tuple): 775 t = tuple([map_aggregate(elem, fn) for elem in a]) 776 # Support NamedTuple (if it has `_fields`) by repacking into original type. 777 return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] 778 elif isinstance(a, list): 779 return immutable_list([map_aggregate(elem, fn) for elem in a]) 780 elif isinstance(a, dict): 781 rv = immutable_dict() 782 for k, v in a.items(): 783 dict.__setitem__(rv, k, map_aggregate(v, fn)) 784 return rv 785 elif isinstance(a, slice): 786 return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) 787 else: 788 return fn(a) 789