xref: /aosp_15_r20/external/pytorch/torch/fx/node.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Nodes represent a definition of a value in our graph of operators.
2from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
3from ._compatibility import compatibility
4from .immutable_collections import immutable_dict, immutable_list
5import torch
6import builtins
7import types
8import inspect
9import warnings
10from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
11from .._ops import ops as _ops
12from torch._C import _NodeBase
13
14if TYPE_CHECKING:
15    from .graph import Graph
16
17__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]
18
# Python values permitted as leaf entries of a Node's ``args``/``kwargs``.
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
                          torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload,
                          torch.SymInt, torch.SymBool, torch.SymFloat]
# Runtime tuple of the classes above, suitable for ``isinstance`` checks.
base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]

# A Node's target: a callable (for 'call_function') or a string name
# (for every other opcode -- see the type check in ``Node.__init__``).
Target = Union[Callable[..., Any], str]

Argument = Optional[Union[
    Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
    List[Any],  # actually Argument
    Dict[str, Any],  # actually Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
    range,
    'Node',
    BaseArgumentTypes
]]
35
# The opcodes a Node may legally carry. Stored as dict keys to act as an
# ordered set. NOTE(review): 'root' is not in the ``Node.__init__`` docstring's
# list of public opcodes -- presumably a Graph-internal sentinel; confirm in graph.py.
_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'])

# Side-effectful callables that must additionally be preserved through
# pre-dispatch tracing (also folded into ``_side_effectful_functions`` below).
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
    torch._C._set_grad_enabled,
    torch.amp._enter_autocast,
    torch.amp._exit_autocast,
}

# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
# or add logic to correctly mark all inplace ops as side effectful.
# Registry of callables whose calls must never be eliminated as dead code
# (see ``Node.is_impure``). Extend via ``has_side_effect``.
_side_effectful_functions: Set[Callable] = {
    torch._assert,
    torch._assert_async,
    _ops.aten._assert_async.msg,
    _ops.aten._assert_scalar.default,
    _ops.aten.sym_constrain_range.default,
    _ops.aten.sym_constrain_range_for_size.default,
    _ops.profiler._record_function_enter,
    _ops.profiler._record_function_enter_new,
    _ops.profiler._record_function_exit,
    _ops.inductor.accumulate_grad_.default,
} | _side_effectful_need_to_be_preserved_pre_dispatch
# This op only exists in some builds; register it when present.
if hasattr(_ops.inductor, "resize_storage_bytes_"):
    _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default)
60
61
@compatibility(is_backward_compatible=False)
def has_side_effect(fn: Callable) -> Callable:
    """Register ``fn`` as side-effectful so dead-code elimination keeps calls to it.

    Usable as a decorator: adds ``fn`` to the module-level
    ``_side_effectful_functions`` registry and returns ``fn`` unchanged.
    """
    _side_effectful_functions.add(fn)
    return fn
66
67
68# this is fixed on master, WAR for 1.5
69def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
70    name = orig_method.__name__
71    module = orig_method.__module__
72    if module is not None:
73        return module
74    for guess in [torch, torch.nn.functional]:
75        if getattr(guess, name, None) is orig_method:
76            return guess.__name__
77    raise RuntimeError(f'cannot find module for {orig_method}')
78
79# Borrowed from CPython typing module
80# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
81def _type_repr(obj: object) -> str:
82    """Return the repr() of an object, special-casing types (internal helper).
83    If obj is a type, we return a shorter version than the default
84    type.__repr__, based on the module and qualified name, which is
85    typically enough to uniquely identify a type.  For everything
86    else, we fall back on repr(obj).
87    """
88    if isinstance(obj, type):
89        if obj.__module__ == 'builtins':
90            return obj.__qualname__
91        return f'{obj.__module__}.{obj.__qualname__}'
92    if obj is ...:
93        return '...'
94    if isinstance(obj, types.FunctionType):
95        return obj.__name__
96    return repr(obj)
97
98def _get_qualified_name(func: Callable[..., Any]) -> str:
99    # things like getattr just appear in builtins
100    if getattr(builtins, func.__name__, None) is func:
101        return func.__name__
102    # torch.Tensor.{fn}
103    if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType))
104       and func is getattr(torch.Tensor, func.__name__, None)):
105        return f"torch.Tensor.{func.__name__}"
106    name = func.__name__
107    if name == "<lambda>":
108        # For lambdas, try to get their defining name in the module
109        try:
110            name = inspect.getsource(func).split("=")[0].strip()
111        except Exception as e:
112            raise RuntimeError("Unable to represent lambda") from e
113    module = _find_module_of_method(func)
114    module = module.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
115    # Fixup segment_reduce mismatch
116    if module == "torch" and name == "segment_reduce":
117        name = "_" + name
118    return f'{module}.{name}'
119
def _format_arg(arg: object, max_list_len: float = float('inf')) -> str:
    """Pretty-print one argument value for graph printouts.

    Lists/tuples/dicts are formatted recursively (nested containers are never
    elided); top-level sequences longer than ``max_list_len`` are truncated
    with a ``...[total_len=N]`` suffix. ``Node`` values render as ``%name``.
    Objects exposing ``_custom_fx_repr_fn`` format themselves.
    """
    if hasattr(arg, '_custom_fx_repr_fn'):
        return arg._custom_fx_repr_fn()
    if isinstance(arg, (list, tuple)):
        shown = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
        suffix = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
        if isinstance(arg, tuple):
            # Single-element tuples keep the trailing comma, like Python syntax.
            comma = ',' if len(arg) == 1 else ''
            return f'({shown}{comma}{suffix})'
        return f'[{shown}{suffix}]'
    if isinstance(arg, dict):
        body = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items())
        return f'{{{body}}}'
    if isinstance(arg, Node):
        return '%' + str(arg)
    return str(arg)
140
141@compatibility(is_backward_compatible=True)
142class Node(_NodeBase):
143    """
144    ``Node`` is the data structure that represents individual operations within
145    a ``Graph``. For the most part, Nodes represent callsites to various entities,
146    such as operators, methods, and Modules (some exceptions include nodes that
147    specify function inputs and outputs). Each ``Node`` has a function specified
148    by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows:
149
150    - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
151      ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
152      denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
153      the function parameters (e.g. ``x``) in the graph printout.
154    - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
155      fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
156      ``args`` and ``kwargs`` are don't-care
157    - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
158      to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
159      following the Python calling convention
160    - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
161      as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
162      ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*.
163    - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
164      to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
165      *including the self argument*
166    - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
167      in the Graph printout.
168    """
169    _args: Tuple['Argument', ...]
170    _kwargs: Dict[str, 'Argument']
171
    @compatibility(is_backward_compatible=True)
    def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
                 args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
                 return_type : Optional[Any] = None) -> None:
        """
        Instantiate an instance of ``Node``. Note: most often, you want to use the
        Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather
        than instantiating a ``Node`` directly.

        Args:
            graph (Graph): The ``Graph`` to which this ``Node`` should belong.

            name (str): The name to which the output of this ``Node`` should be assigned

            op (str): The opcode for this ``Node``. Can be one of 'placeholder',
                'call_method', 'call_module', 'call_function', 'get_attr',
                'output'

            target ('Target'): The target this op should call. See the broader
                ``Node`` docstring for more details.

            args (Tuple['Argument']): The args to be passed to ``target``

            kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target``

            return_type (Optional[Any]): The python type expression representing the
                type of the output of this node. This field can be used for
                annotation of values in the generated code or for other types
                of analyses.

        Raises:
            ValueError: if ``target``'s type does not match ``op``
                ('call_function' requires a callable; all other ops require a str).
        """
        super().__init__()
        self.graph = graph
        self.name = name  # unique name of value being created
        assert op in _legal_ops
        self.op = op  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
        # Validate the target's type against the opcode before storing it.
        if op == 'call_function':
            if not callable(target):
                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
                                 'but a Callable is expected')
        else:
            if not isinstance(target, str):
                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
                                 'but a str is expected')
        self.target = target  # for method/module/function, the name of the method/module/function/attr
        # being invoked, e.g add, layer1, or torch.add

        # All `Node`-valued inputs. Key is the Node, value is don't-care.
        # The public API for this is `all_input_nodes`, this private attribute
        # should not be accessed directly.
        # NOTE: must be initialized before __update_args_kwargs runs, since
        # that call populates it.
        self._input_nodes : Dict[Node, None] = {}
        self.__update_args_kwargs(args, kwargs)

        # All of the nodes that use the value produced by this Node
        # Note one user may correspond to several uses, e.g. the node for ``x + x``
        # would appear once here, but represents two uses.
        #
        # Is a dict to act as an "ordered set". Keys are significant, value dont-care
        self.users : Dict[Node, None] = {}
        # Type expression representing the output value of this node.
        # This should contain the same class of Type objects that would appear
        # as type annotations for function inputs/outputs.
        #
        # For placeholder nodes, this value will be used to type-annotate the
        # generated function parameters.
        # For the return node, this value will be used to type-annotate the
        # generated function return type. (Note this is a special case. ``return``
        # does not produce a value, it's more of a notation. Thus, this value
        # describes the type of args[0] in the ``return`` node.)
        self.type : Optional[Any] = return_type
        # Tuple of ints encoding this node's topological position; maintained
        # by ``prepend`` and consumed by the rich-comparison operators.
        self._sort_key: Any = ()

        # If set, use this fn to print this node
        self._repr_fn : Optional[Callable[[Node], str]] = None

        # Dictionary to store metadata passes need to do their
        # transformations. This metadata is preserved across node copies
        self.meta : Dict[str, Any] = {}
249
250    def __getstate__(self) -> Dict[str, Any]:
251        state = self.__dict__.copy()
252        state["_erased"] = self._erased
253        state["_prev"] = self._prev
254        state["_next"] = self._next
255        return state
256
257    def __setstate__(self, state: Dict[str, Any]) -> None:
258        _erased = state.pop("_erased")
259        _prev = state.pop("_prev")
260        _next = state.pop("_next")
261        self.__dict__.update(state)
262        self._erased = _erased
263        self._prev = _prev
264        self._next = _next
265
    @property
    def next(self) -> 'Node':
        """
        Returns the next ``Node`` in the linked list of Nodes.

        (Nodes form a doubly-linked list; links are maintained by
        ``prepend``/``_remove_from_list``.)

        Returns:

            The next ``Node`` in the linked list of Nodes.
        """
        return self._next
276
    @property
    def prev(self) -> 'Node':
        """
        Returns the previous ``Node`` in the linked list of Nodes.

        (Nodes form a doubly-linked list; links are maintained by
        ``prepend``/``_remove_from_list``.)

        Returns:

            The previous ``Node`` in the linked list of Nodes.
        """
        return self._prev
287
    @compatibility(is_backward_compatible=True)
    def prepend(self, x: 'Node') -> None:
        """
        Insert x before this node in the list of nodes in the graph. Example::

            Before: p -> self
                    bx -> x -> ax
            After:  p -> x -> self
                    bx -> ax

        Args:
            x (Node): The node to put before this node. Must be a member of the same graph.
        """
        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
        if self == x:
            warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.")
            return
        # Unlink x from its current position, then splice it in before self.
        x._remove_from_list()
        p = self._prev
        p._next, x._prev = x, p
        x._next, self._prev = self, x

        # compute x._sort_key
        # _sort_key is a variable-length tuple of ints compared
        # lexicographically; a key between the two neighbors' keys is derived
        # from the longer neighbor's prefix so no other node's key needs to
        # be rewritten.
        psk = x._prev._sort_key
        nsk = x._next._sort_key
        if len(psk) > len(nsk):
            idx: int
            # prev's key is longer: bump the digit one position past nsk's length
            *prefix, idx = psk[:len(nsk) + 1]
            x._sort_key = (*prefix, idx + 1)
        elif len(psk) < len(nsk):
            # next's key is longer: decrement the digit one position past psk's length
            *prefix, idx = nsk[:len(psk) + 1]
            x._sort_key = (*prefix, idx - 1)
        else:  # same length, increase length by 1
            x._sort_key = (*psk, 0)
322
    def __gt__(self, other: 'Node') -> bool:
        # Ordering delegates to _sort_key, which encodes topological position
        # in the node list (maintained by ``prepend``).
        return self._sort_key > other._sort_key
325
    def __lt__(self, other: 'Node') -> bool:
        # Ordering delegates to _sort_key, which encodes topological position
        # in the node list (maintained by ``prepend``).
        return self._sort_key < other._sort_key
328
    def __ge__(self, other: 'Node') -> bool:
        # Expressed via ``>`` plus equality rather than comparing keys, so
        # equality keeps its default (identity-based) meaning.
        return self > other or self == other
331
    def __le__(self, other: 'Node') -> bool:
        # Expressed via ``<`` plus equality rather than comparing keys, so
        # equality keeps its default (identity-based) meaning.
        return self < other or self == other
334
335    @compatibility(is_backward_compatible=True)
336    def append(self, x: 'Node') -> None:
337        """
338        Insert ``x`` after this node in the list of nodes in the graph.
339        Equivalent to ``self.next.prepend(x)``
340
341        Args:
342            x (Node): The node to put after this node. Must be a member of the same graph.
343        """
344        self._next.prepend(x)
345
346    def _remove_from_list(self) -> None:
347        p, n = self._prev, self._next
348        p._next, n._prev = n, p
349
    @property
    def args(self) -> Tuple[Argument, ...]:
        """
        The tuple of arguments to this ``Node``. The interpretation of arguments
        depends on the node's opcode. See the :class:`Node` docstring for more
        information.

        Assignment to this property is allowed. All accounting of uses and users
        is updated automatically on assignment. (Container values are also
        normalized on assignment, e.g. ``list`` -> ``immutable_list``.)
        """
        return self._args
361
    @args.setter
    def args(self, a : Tuple[Argument, ...]) -> None:
        """
        Set the tuple of arguments to this Node. The interpretation of arguments
        depends on the node's opcode. See the ``fx.Graph`` docstring for more
        information.
        """
        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
        # set `args` is via direct assignment, i.e. `node.args = new_args`
        self.__update_args_kwargs(a, self._kwargs)
372
    @property
    def kwargs(self) -> Dict[str, Argument]:
        """
        The dict of keyword arguments to this ``Node``. The interpretation of arguments
        depends on the node's opcode. See the :class:`Node` docstring for more
        information.

        Assignment to this property is allowed. All accounting of uses and users
        is updated automatically on assignment. (Container values are also
        normalized on assignment, e.g. ``dict`` -> ``immutable_dict``.)
        """
        return self._kwargs
384
    @kwargs.setter
    def kwargs(self, k : Dict[str, Argument]) -> None:
        """
        Set the dict of kwargs to this Node. The interpretation of arguments
        depends on the node's opcode. See the ``fx.Graph`` docstring for more
        information.
        """
        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
        # set `kwargs` is via direct assignment, i.e. `node.kwargs = new_kwargs`
        self.__update_args_kwargs(self._args, k)
395
396    @property
397    def all_input_nodes(self) -> List['Node']:
398        """
399        Return all Nodes that are inputs to this Node. This is equivalent to
400        iterating over ``args`` and ``kwargs`` and only collecting the values that
401        are Nodes.
402
403        Returns:
404
405            List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
406            ``Node``, in that order.
407        """
408        return list(self._input_nodes.keys())
409
410    @compatibility(is_backward_compatible=True)
411    def update_arg(self, idx : int, arg : Argument) -> None:
412        """
413        Update an existing positional argument to contain the new value
414        ``arg``. After calling, ``self.args[idx] == arg``.
415
416        Args:
417
418            idx (int): The index into ``self.args`` of the element to update
419            arg (Argument): The new argument value to write into ``args``
420        """
421        args = list(self.args)
422        args[idx] = arg
423        self.args = tuple(args)
424
    @compatibility(is_backward_compatible=True)
    def insert_arg(self, idx : int, arg : Argument) -> None:
        """
        Insert a positional argument into the argument list at the given index.

        Args:

            idx (int): The index of the element in ``self.args`` to be inserted before.
            arg (Argument): The new argument value to insert into ``args``
        """
        assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)"
        args_left = self.args[:idx]
        args_right = self.args[idx:]

        # Writes _args directly, bypassing the ``args`` setter: only the
        # newly-inserted value needs use/user bookkeeping, so there is no need
        # to re-walk every existing argument.
        self._args = args_left + (arg,) + args_right

        # Collect the Node values (if any) contained in the inserted argument.
        _new_input_nodes: Dict[Node, None] = {}
        map_arg(arg, _new_input_nodes.setdefault)

        # Register the use/user relationship for Nodes that were not already
        # inputs of this node.
        for new_use in _new_input_nodes.keys():
            if new_use not in self._input_nodes:
                self._input_nodes.setdefault(new_use)
                new_use.users.setdefault(self)
448
449    @compatibility(is_backward_compatible=True)
450    def update_kwarg(self, key : str, arg : Argument) -> None:
451        """
452        Update an existing keyword argument to contain the new value
453        ``arg``. After calling, ``self.kwargs[key] == arg``.
454
455        Args:
456
457            key (str): The key in ``self.kwargs`` of the element to update
458            arg (Argument): The new argument value to write into ``kwargs``
459        """
460        self.kwargs = {**self.kwargs, key: arg}
461
462    @property
463    def stack_trace(self) -> Optional[str]:
464        """
465        Return the Python stack trace that was recorded during tracing, if any.
466        When traced with fx.Tracer, this property is usually populated by
467        `Tracer.create_proxy`. To record stack traces during tracing for debug purposes,
468        set `record_stack_traces = True` on the `Tracer` instance.
469        When traced with dynamo, this property will be populated by default by
470        `OutputGraph.create_proxy`.
471
472        stack_trace would have the innermost frame at the end of the string.
473        """
474        return self.meta.get("stack_trace", None)
475
    @stack_trace.setter
    def stack_trace(self, trace : Optional[str]) -> None:
        # Stored in ``meta`` so the trace is preserved across node copies,
        # like all other node metadata.
        self.meta["stack_trace"] = trace
479
    def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None:
        """
        This API is internal. Do *not* call it directly.

        Replaces ``_args``/``_kwargs`` wholesale and rebuilds the use/user
        bookkeeping between this node and every Node referenced by the new
        values.
        """
        def update_users_and_input_nodes(n: Any) -> Any:
            # Invoked by map_aggregate on every leaf value; records Node leaves
            # in both directions of the use/user relationship.
            if isinstance(n, Node):
                self._input_nodes.setdefault(n)
                n.users.setdefault(self)
            return n

        # Clear prior users and input_nodes. This must happen before the
        # re-population pass below so stale uses don't linger.
        for old_use in self._input_nodes.keys():
            old_use.users.pop(self)
        self._input_nodes = {}

        # We do three things in a single pass of the args
        # - Normalize list->immutable_list, dict->immutable_dict, etc
        # - Populate self._input_nodes
        # - Populate arg.users[self] for each arg
        self._args = map_aggregate(new_args, update_users_and_input_nodes)  # type: ignore[assignment]
        self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes)  # type: ignore[assignment]
501
502    def __repr__(self) -> str:
503        if self._repr_fn:
504            return self._repr_fn(self)
505        return self.name
506
507    def _pretty_print_target(self, target: object) -> str:
508        """
509        Make target printouts more user-friendly.
510        1) builtins will be printed as `builtins.xyz`
511        2) operators will be printed as `operator.xyz`
512        3) other callables will be printed with qualified name, e.g. torch.add
513        """
514        if isinstance(target, str):
515            return target
516        if hasattr(target, '__module__'):
517            name = getattr(target, '__name__', None)
518            if name is None:
519                # Just to be defensive, if we don't have `__name__`, get the
520                # qualname. Not sure if this happens for any members of `operator`
521                # or `builtins`. This fallback path is not as good, since e.g.
522                # things in `operator` have `_operator` as their __module__.
523                # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__`
524                return _get_qualified_name(target)  # type: ignore[arg-type]
525            if target.__module__ == 'builtins':
526                return f'builtins.{name}'
527            elif target.__module__ == '_operator':
528                return f'operator.{name}'
529        return _get_qualified_name(target)  # type: ignore[arg-type]
530
531    @compatibility(is_backward_compatible=True)
532    def format_node(self,
533                    placeholder_names: Optional[List[str]] = None,
534                    maybe_return_typename: Optional[List[str]] = None) -> Optional[str]:
535        """
536        Return a descriptive string representation of ``self``.
537
538        This method can be used with no arguments as a debugging
539        utility.
540
541        This function is also used internally in the ``__str__`` method
542        of ``Graph``. Together, the strings in ``placeholder_names``
543        and ``maybe_return_typename`` make up the signature of the
544        autogenerated ``forward`` function in this Graph's surrounding
545        GraphModule. ``placeholder_names`` and ``maybe_return_typename``
546        should not be used otherwise.
547
548        Args:
549            placeholder_names: A list that will store formatted strings
550                representing the placeholders in the generated
551                ``forward`` function. Internal use only.
552            maybe_return_typename: A single-element list that will store
553                a formatted string representing the output of the
554                generated ``forward`` function. Internal use only.
555
556        Returns:
557            str: If 1) we're using ``format_node`` as an internal helper
558                in the ``__str__`` method of ``Graph``, and 2) ``self``
559                is a placeholder Node, return ``None``. Otherwise,
560                return a  descriptive string representation of the
561                current Node.
562        """
563        if self.op == 'placeholder':
564            assert isinstance(self.target, str)
565            arg_str = self.target
566            arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else ''
567            if placeholder_names:
568                placeholder_names.append(arg_str)
569                return None
570            maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
571            default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
572            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
573        elif self.op == 'get_attr':
574            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
575            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
576                   f'{self.op}[target={self._pretty_print_target(self.target)}]'
577        elif self.op == 'output':
578            if self.type and maybe_return_typename:
579                maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
580            return f'return {self.args[0]}'
581        else:
582            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
583            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
584                   f'{self.op}[target={self._pretty_print_target(self.target)}](' \
585                   f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
586
    @compatibility(is_backward_compatible=True)
    def replace_all_uses_with(self,
                              replace_with: 'Node',
                              delete_user_cb: Callable[['Node'], bool] = lambda user: True,
                              *,
                              propagate_meta: bool = False
                              ) -> List['Node']:
        """
        Replace all uses of ``self`` in the Graph with the Node ``replace_with``.

        Args:

            replace_with (Node): The node to replace all uses of ``self`` with.
            delete_user_cb (Callable): Callback that is called to determine
              whether a given user of the self node should be removed.
            propagate_meta (bool): Whether or not to copy all properties
              on the .meta field of the original node onto the replacement node.
              For safety, this is only valid to do if the replacement node
              doesn't already have an existing .meta field.

        Returns:

            The list of Nodes on which this change was made.
        """
        if propagate_meta:
            assert len(replace_with.meta) == 0, \
                'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \
                'but replace_with already has .meta keys'
            for k, v in self.meta.items():
                replace_with.meta[k] = v
        # Snapshot the users before rewriting: the loop below shrinks
        # self.users through __update_args_kwargs.
        to_process = list(self.users)
        # Users delete_user_cb told us to leave untouched.
        skipped = []
        m = self.graph.owning_module
        for use_node in to_process:
            if not delete_user_cb(use_node):
                skipped.append(use_node)
                continue

            def maybe_replace_node(n : Node) -> Node:
                if n == self:
                    return replace_with
                else:
                    return n

            # Let the owning module's replace hook (if installed) observe the
            # rewrite of this user.
            if getattr(m, "_replace_hook", None):
                m._replace_hook(old=self, new=replace_with.name, user=use_node)

            new_args = map_arg(use_node.args, maybe_replace_node)
            new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
            assert isinstance(new_args, tuple)
            assert isinstance(new_kwargs, dict)
            use_node.__update_args_kwargs(new_args, new_kwargs)

        # Every non-skipped user must have dropped self from its inputs.
        assert len(self.users) - len(skipped) == 0
        return [n for n in to_process if n not in skipped]
642
643    @compatibility(is_backward_compatible=False)
644    def is_impure(self) -> bool:
645        """
646        Returns whether this op is impure, i.e. if its op is a placeholder or
647        output, or if a call_function or call_module which is impure.
648
649        Returns:
650
651            bool: If the op is impure or not.
652        """
653        if self.op in {"placeholder", "output"}:
654            return True
655
656        # Check if an impure function based on schema.
657        if self.op == "call_function":
658            schema = getattr(self.target, "_schema", None)
659            schema_mutable = schema is not None and schema.is_mutable
660            return schema_mutable or self.target in _side_effectful_functions
661
662        # Check if an impure module.
663        if self.op == "call_module":
664            assert (
665                self.graph.owning_module is not None
666            ), "self.graph.owning_module not set for purity check"
667            target_mod = self.graph.owning_module.get_submodule(self.target)
668            assert (
669                target_mod is not None
670            ), f"Did not find expected submodule target {self.target}"
671            return getattr(target_mod, "_is_impure", False)
672
673        return False
674
675    @compatibility(is_backward_compatible=False)
676    def normalized_arguments(
677            self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
678            kwarg_types : Optional[Dict[str, Any]] = None,
679            normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
680        """
681        Returns normalized arguments to Python targets. This means that
682        `args/kwargs` will be matched up to the module/functional's
683        signature and return exclusively kwargs in positional order
684        if `normalize_to_only_use_kwargs` is true.
685        Also populates default values. Does not support positional-only
686        parameters or varargs parameters.
687
688        Supports module calls.
689
690        May require `arg_types` and `kwarg_types` in order to disambiguate overloads.
691
692        Args:
693            root (torch.nn.Module): Module upon which to resolve module targets.
694            arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
695            kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
696            normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
697
698        Returns:
699
700            Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
701        """
702        if self.op == 'call_function':
703            assert callable(self.target)
704            return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types)  # type: ignore[arg-type]
705        elif self.op == 'call_module':
706            assert isinstance(self.target, str)
707            return normalize_module(root, self.target, self.args, self.kwargs)  # type: ignore[arg-type]
708
709        return None
710
711    @compatibility(is_backward_compatible=True)
712    def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None:
713        """
714        Loop through input nodes of ``self``, and replace all instances of
715        ``old_input`` with ``new_input``.
716
717        Args:
718
719            old_input (Node): The old input node to be replaced.
720            new_input (Node): The new input node to replace ``old_input``.
721        """
722        def maybe_replace_node(n : Node) -> Node:
723            return new_input if n == old_input else n
724
725        m = self.graph.owning_module
726        if getattr(m, "_replace_hook", None):
727            m._replace_hook(old=old_input, new=new_input.name, user=self)
728
729        new_args = map_arg(self.args, maybe_replace_node)
730        new_kwargs = map_arg(self.kwargs, maybe_replace_node)
731        assert isinstance(new_args, tuple)
732        assert isinstance(new_kwargs, dict)
733        self.__update_args_kwargs(new_args, new_kwargs)
734
735    def _rename(self, candidate: str) -> None:
736        if candidate == self.name:
737            return
738        name = self.graph._graph_namespace.create_name(candidate, None)
739        self.name = name
740        self.graph._graph_namespace._rename_object(self, name)
741
    def __setattr__(self, name: str, value: Any) -> None:
        # Intercepts every attribute write so graph-level bookkeeping stays
        # consistent when a node is mutated.
        # Renaming an already-named node: notify the owning module's replace
        # hook (if one is installed) once per user of this node, so observers
        # can track the name change. The `hasattr(self, "name")` guard skips
        # the very first assignment during construction.
        if name == 'name' and hasattr(self, "name"):
            m = self.graph.owning_module
            if getattr(m, "_replace_hook", None):
                assert isinstance(value, str)
                for user in self.users:
                    m._replace_hook(old=self, new=value, user=user)
        # If this node is currently indexed in the graph's find_nodes lookup
        # table, remove it before mutating and re-insert afterwards — the
        # table's indexing presumably depends on attributes being set here
        # (e.g. op/target), so updating in place would corrupt the index.
        update = False
        if (
                hasattr(self, name) and
                hasattr(self.graph, "_find_nodes_lookup_table") and
                self in self.graph._find_nodes_lookup_table
        ):
            update = True
            self.graph._find_nodes_lookup_table.remove(self)
        object.__setattr__(self, name, value)
        if update:
            self.graph._find_nodes_lookup_table.insert(self)
760
@compatibility(is_backward_compatible=True)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
    """
    Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
    """
    assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"

    def _on_node(x : Argument) -> Argument:
        # Transform only Node leaves; every other value passes through as-is.
        if isinstance(x, Node):
            return fn(x)
        return x

    return map_aggregate(a, _on_node)
768
@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
    """
    Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
    """
    # Recurse structurally through containers; apply `fn` only at the leaves.
    if isinstance(a, slice):
        return slice(
            map_aggregate(a.start, fn),
            map_aggregate(a.stop, fn),
            map_aggregate(a.step, fn),
        )
    if isinstance(a, tuple):
        mapped = tuple(map_aggregate(elem, fn) for elem in a)
        # NamedTuples expose `_fields`; repack so the original subtype survives.
        if hasattr(a, '_fields'):
            return type(a)(*mapped)  # type: ignore[arg-type]
        return mapped
    if isinstance(a, list):
        return immutable_list(map_aggregate(elem, fn) for elem in a)
    if isinstance(a, dict):
        out = immutable_dict()
        for key, value in a.items():
            # Write through the base-class setter, since immutable_dict
            # presumably forbids normal item assignment after construction.
            dict.__setitem__(out, key, map_aggregate(value, fn))
        return out
    return fn(a)
789