xref: /aosp_15_r20/external/pytorch/torch/ao/ns/fx/graph_passes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
3
4import torch
5from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
6from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
7from torch.ao.quantization.observer import _is_activation_post_process
8from torch.fx import GraphModule, map_arg
9from torch.fx.graph import Graph, Node
10
11from .ns_types import NSNodeTargetType, NSSingleResultValuesType, NSSubgraph
12from .utils import (
13    get_arg_indices_of_inputs_to_log,
14    get_node_first_input_and_output_type,
15    get_node_input_qparams,
16    get_normalized_nth_input,
17    get_number_of_non_param_args,
18    get_target_type_str,
19    getattr_from_fqn,
20    NodeInputOrOutputType,
21    op_type_supports_shadowing,
22    return_first_non_observer_node,
23)
24
25
def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    """
    Look up the fully qualified name (fqn) recorded for ``node`` in
    ``gm._node_name_to_scope``.

    Returns None when ``gm`` has no ``_node_name_to_scope`` attribute.

    Observers are not present when fqns are created during tracing, so they
    have no entry of their own. For a ``call_module`` node whose module is an
    observer, the entry of the observed node (its first normalized input) is
    used instead. The fqn is the first element of the looked-up entry.
    """
    if not hasattr(gm, "_node_name_to_scope"):
        return None

    lookup_node = node
    if node.op == "call_module":
        assert isinstance(node.target, str)
        if _is_activation_post_process(getattr_from_fqn(gm, node.target)):
            # observers have no scope entry; use the node being observed
            lookup_node = get_normalized_nth_input(node, gm, 0)
    return gm._node_name_to_scope[lookup_node.name][0]  # type: ignore[index]
40
41
def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Create a logger module, attach it to ``gm``, and add a ``call_module``
    node to ``node.graph`` that calls the logger on ``node``.

    The logger is constructed positionally as::

        logger_cls(ref_node_name, node.name, model_name, ref_name,
                   <target type string of node>, ref_node_target_type,
                   results_type, index_within_arg, index_of_arg, fqn)

    and stored on ``gm`` under a fresh attribute name derived from
    ``node.name + logger_node_name_suffix``.

    Existing users of ``node`` are not rewired. The caller decides what
    consumes the returned node, e.g. to obtain
    ``prev_node -> node -> logger_node -> next_node``.

    Returns:
        The newly created logger node.
    """
    attr_name = get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(
        gm
    )
    logger_ctor_args = (
        ref_node_name,
        node.name,
        model_name,
        ref_name,
        get_target_type_str(node, gm),
        ref_node_target_type,
        results_type,
        index_within_arg,
        index_of_arg,
        fqn,
    )
    # the logger module must live on gm so the call_module target resolves
    setattr(gm, attr_name, logger_cls(*logger_ctor_args))
    return node.graph.create_node("call_module", attr_name, (node,), {})
88
89
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm and copies it into a new graph, adding loggers to
    the inputs and/or outputs of the requested nodes. Returns a GraphModule
    built from ``gm`` and the new graph.

    Args:
        gm: the model to instrument.
        node_to_instrument_inputs_to_ref_node_name: maps a node whose inputs
            should be logged to a ``(ref_name, ref_node_type)`` pair.
        node_to_instrument_outputs_to_ref_node_name: maps a node whose
            output should be logged to a ``(ref_name, ref_node_type)`` pair.
        logger_cls: the logger module class, instantiated once per logger.
        model_name: name of the model, passed to every logger.

    Returns:
        A new GraphModule containing the instrumented graph.
    """
    new_graph = Graph()
    # Maps the name of a node in gm.graph to the node in new_graph that
    # downstream copies should consume: the copy itself, or the last logger
    # inserted after it.
    env: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == "output":
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        log_inputs = node in node_to_instrument_inputs_to_ref_node_name
        log_output = node in node_to_instrument_outputs_to_ref_node_name
        if not (log_inputs or log_output):
            env[node.name] = new_graph.node_copy(node, load_arg)
            continue

        fqn = _maybe_get_fqn(node, gm)

        if log_inputs:
            ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
            # Ops such add and mul are special because either
            # one or two of the first two arguments can be tensors,
            # and if one argument is a tensor it can be first or
            # second (x + 1 versus 1 + x).
            for node_arg_idx in get_arg_indices_of_inputs_to_log(node):
                node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                if isinstance(node_arg, Node):
                    # a single input: one logger
                    inputs_to_log = [(0, node_arg)]
                elif isinstance(
                    node_arg, torch.fx.immutable_collections.immutable_list
                ):
                    # a list of inputs: one logger per Node in the list;
                    # non-Node entries (constants) have nothing to log
                    inputs_to_log = [
                        (idx, el)
                        for idx, el in enumerate(node_arg)
                        if isinstance(el, Node)
                    ]
                else:
                    # constants and other non-Node values are not logged
                    continue

                for index_within_arg, input_node in inputs_to_log:
                    # Key env by the ORIGINAL input's name (not by the name
                    # of the node it currently maps to) so that the copy of
                    # `node` made below via load_arg consumes this logger.
                    env[input_node.name] = _insert_logger_after_node(
                        env[input_node.name],
                        gm,
                        logger_cls,
                        "_ns_logger_",
                        node.name,
                        model_name,
                        ref_name,
                        ref_node_type,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=index_within_arg,
                        index_of_arg=node_arg_idx,
                        fqn=fqn,
                    )

        # ensure env is populated with base node
        # Note: runs for both inputs and outputs
        env[node.name] = new_graph.node_copy(node, load_arg)

        if log_output:
            ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
            # add the logger after the base node
            env[node.name] = _insert_logger_after_node(
                env[node.name],
                gm,
                logger_cls,
                "_ns_logger_",
                node.name,
                model_name,
                ref_name,
                ref_node_type,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0,
                index_of_arg=0,
                fqn=fqn,
            )

    return GraphModule(gm, new_graph)
200
201
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    """
    Add ``torch.quantize_per_tensor(prev_node_c, scale, zero_point, torch.quint8)``
    to ``graph_c``.

    ``scale`` and ``zero_point`` are stored as new attributes of ``gm_b``. Their
    names are derived from ``node_a.name``, and they are read back through
    ``get_attr`` nodes. The returned ``call_function`` node is created with the
    name ``dtype_cast_name``.
    """

    def _add_constant(name_suffix: str, value: Any) -> Node:
        # store `value` on gm_b under a fresh name and read it via get_attr
        attr_name = get_new_attr_name_with_prefix(node_a.name + name_suffix)(gm_b)
        setattr(gm_b, attr_name, value)
        return graph_c.create_node("get_attr", attr_name, (), {}, attr_name)

    scale_node = _add_constant("_input_scale_", scale)
    zero_point_node = _add_constant("_input_zero_point_", zero_point)
    quantize_args = (prev_node_c, scale_node, zero_point_node, torch.quint8)
    return graph_c.create_node(
        "call_function",
        torch.quantize_per_tensor,
        quantize_args,
        {},
        dtype_cast_name,
    )
233
234
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    The cast is chosen from the first input types of node_a and node_c:

    * node_a FP32, node_c INT8 / FP16 / FP32_OR_INT8: ``torch.dequantize``
    * same known type on both sides: a ``torch.nn.Identity`` module
    * node_a INT8, node_c FP32: ``torch.quantize_per_tensor`` with node_a's
      input qparams (only when those qparams can be found)
    * node_a FP16, node_c FP32: ``.to(torch.float16)``

    If ``prev_node_c`` is a list, one cast node is created per element and the
    list of cast nodes is returned; otherwise a single cast node is returned.

    Raises:
        AssertionError: if no cast is implemented for the pair of input types,
            if a quantize cast is needed but node_a's input qparams are
            unknown, or if ``prev_node_c`` is neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _ = get_node_first_input_and_output_type(
        node_a, gm_a, logger_cls, node_type_to_io_type_map
    )
    node_input_type_c, _ = get_node_first_input_and_output_type(
        node_c, gm_b, logger_cls, node_type_to_io_type_map
    )

    if node_input_type_a == NodeInputOrOutputType.FP32 and node_input_type_c in (
        NodeInputOrOutputType.INT8,
        NodeInputOrOutputType.FP16,
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        NodeInputOrOutputType.FP32_OR_INT8,
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c
        and node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8
        and node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map
        )
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16
        and node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = "to"
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to "
            + f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    def _cast(prev: Node) -> Node:
        # Creates the cast for a single input node. Used for both the
        # single-node and the list case so that every kind of cast
        # (function, quantize, method, module) is available in both.
        new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op is not None:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev,
                    node_a,
                    gm_b,
                    graph_c,
                    dtype_cast_scale,
                    dtype_cast_zero_point,
                    new_dtype_cast_name,
                )
            return graph_c.create_node(
                "call_function", dtype_cast_op, (prev,), {}, new_dtype_cast_name
            )
        if dtype_cast_method is not None:
            return graph_c.create_node(
                "call_method",
                dtype_cast_method,
                (prev, dtype_cast_method_dtype),
                {},
                new_dtype_cast_name,
            )
        if dtype_cast_mod_cls is None:
            # only reachable for int8 <- fp32 when node_a's qparams are unknown
            raise AssertionError(
                f"could not determine a dtype cast from {node_input_type_c} "
                + f"{node_c.format_node()} to {node_input_type_a} "
                + f"{node_a.format_node()}"
            )
        # the module attribute name doubles as the node name
        setattr(gm_b, new_dtype_cast_name, dtype_cast_mod_cls())
        return graph_c.create_node(
            "call_module", new_dtype_cast_name, (prev,), {}, new_dtype_cast_name
        )

    if isinstance(prev_node_c, Node):
        return _cast(prev_node_c)
    if isinstance(prev_node_c, list):
        return [_cast(prev_node_c_inner) for prev_node_c_inner in prev_node_c]
    raise AssertionError(f"type {type(prev_node_c)} is not handled")
392
393
# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of ``node_a`` (from the graph of ``gm_a``) into ``graph_c``.

    Supported cases:

    * ``get_attr``: the attribute is fetched from ``gm_a`` (tensors are
      detached) and stored on ``gm_b`` under a fresh
      ``<node_a.name>_shadow_copy_<i>`` name, which a new ``get_attr`` node
      reads.
    * ``call_method`` with target ``dequantize`` or ``to``: the first input is
      copied recursively, then the method call is recreated on that copy.
      For ``to``, the second input of ``node_a`` is reused as-is.

    Raises:
        AssertionError: for any other op, or any other method target.
    """
    copy_prefix = node_a.name + "_shadow_copy_"

    if node_a.op == "get_attr":
        copy_name = get_new_attr_name_with_prefix(copy_prefix)(gm_b)
        attr_value = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(attr_value):
            attr_value = attr_value.detach()
        setattr(gm_b, copy_name, attr_value)
        return graph_c.create_node("get_attr", copy_name, (), {}, copy_name)

    if node_a.op != "call_method":
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented"
        )

    assert node_a.target in (
        "dequantize",
        "to",
    ), f"target {node_a.target} is not implemented"
    # the input must exist in graph_c before the method call can refer to it
    input_copy = _copy_node_from_a_to_c(
        get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
    )
    copy_name = get_new_attr_name_with_prefix(copy_prefix)(gm_b)
    if node_a.target == "dequantize":
        method_args = (input_copy,)
    else:  # to
        method_args = (input_copy, get_normalized_nth_input(node_a, gm_a, 1))
    return graph_c.create_node(
        "call_method", node_a.target, method_args, {}, copy_name
    )
452
453
# Non-Node argument values of these types (and ``None``) are reused verbatim
# when a node of graph A is copied into graph C.
_COPYABLE_LITERAL_TYPES = (int, float, str, torch.dtype)


def _is_copyable_literal(arg: Any) -> bool:
    """
    Return True if ``arg`` is a non-Node argument value that can be reused
    as-is when copying a node from graph A into graph C.

    Accepted values are ``None``, ints (including bools), floats, strings,
    ``torch.dtype`` objects, and lists/tuples that contain no ``Node``
    elements. Copying a ``Node`` nested inside a list is not implemented.

    This predicate is shared by ``_can_insert_copy_of_subgraph_a`` (the
    check) and ``_insert_copy_of_node_a_after_input_node_c`` (the copy), so
    the two always agree.
    """
    if arg is None or isinstance(arg, _COPYABLE_LITERAL_TYPES):
        return True
    if isinstance(arg, (list, tuple)):
        return not any(isinstance(el, Node) for el in arg)
    return False


def _get_normalized_args_and_kwargs(node: Node, gm: GraphModule) -> Tuple[Any, Any]:
    """
    Return ``(args, kwargs)`` of ``node`` as produced by
    ``node.normalized_arguments(gm, normalize_to_only_use_kwargs=True)``,
    falling back to ``node.args`` / ``node.kwargs`` when that returns None.
    """
    norm_args_kwargs = node.normalized_arguments(
        gm, normalize_to_only_use_kwargs=True
    )
    if norm_args_kwargs is None:
        return node.args, node.kwargs
    norm_args, norm_kwargs = norm_args_kwargs
    return norm_args, norm_kwargs


def _get_subgraph_nodes_in_order(subgraph: NSSubgraph, gm: GraphModule) -> List[Node]:
    """
    Return the nodes of ``subgraph`` ordered from ``start_node`` to
    ``end_node``. They are found by following the first normalized input
    backwards from ``end_node``.
    """
    nodes = [subgraph.end_node]
    cur_node = subgraph.end_node
    while cur_node != subgraph.start_node:
        cur_node = get_normalized_nth_input(cur_node, gm, 0)  # type: ignore[assignment]
        nodes.append(cur_node)
    nodes.reverse()
    return nodes


def _can_insert_copy_of_subgraph_a(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    num_non_param_args_node_a: int,
) -> bool:
    """
    This function returns `False` if the input subgraph cannot be copied by
    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that there is a corner case logic for which copy is not yet implemented.

    The check mirrors the copy logic in
    `_insert_copy_of_node_a_after_input_node_c`. For every node of the
    subgraph, argument 0 is not checked, because it is stitched to an input
    from graph C. For the first node, argument 1 is also skipped when
    ``num_non_param_args_node_a == 2``. Every other argument must be one of:

    * a Node for which ``return_first_non_observer_node`` yields a
      ``get_attr`` node or a ``dequantize`` / ``to`` method call, or
    * a value accepted by ``_is_copyable_literal``.
    """

    def _can_insert(node_a_arg) -> bool:
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            if arg_a.op == "call_method":
                return arg_a.target in ("dequantize", "to")
            return arg_a.op == "get_attr"
        return _is_copyable_literal(node_a_arg)

    for node_idx, node_a in enumerate(_get_subgraph_nodes_in_order(subgraph_a, gm_a)):
        # only the first node of the subgraph may take a second non-param input
        local_num_non_param_args_node_a = (
            num_non_param_args_node_a if node_idx == 0 else 1
        )
        norm_args, norm_kwargs = _get_normalized_args_and_kwargs(node_a, gm_a)
        # args and kwargs share one running index, matching the copy logic
        all_values = list(norm_args) + list(norm_kwargs.values())
        for arg_idx, value in enumerate(all_values):
            if arg_idx == 0:
                continue  # stitched to the input from graph C
            if arg_idx == 1 and local_num_non_param_args_node_a == 2:
                continue  # stitched to the second input from graph C
            if not _can_insert(value):
                return False

    return True


def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Copy the nodes of ``subgraph_a`` (ordered from ``start_node`` to
    ``end_node``) into the graph that owns ``input_node_c``.

    The copy of ``start_node`` takes ``input_node_c`` (and ``input_node_c_2``,
    if given) as its non-param inputs. Each following copy takes the
    previously copied node as its only non-param input. Each node is copied
    with ``_insert_copy_of_node_a_after_input_node_c``.

    Returns:
        The copy of ``subgraph_a.end_node`` in graph C.
    """
    nodes_of_a = _get_subgraph_nodes_in_order(subgraph_a, gm_a)

    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
        input_node_c, input_node_c_2, nodes_of_a[0], gm_a, gm_b, node_name_prefix
    )
    for cur_node_a in nodes_of_a[1:]:
        # previous added node is the input to next node
        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
            cur_node_c,
            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
            None,
            cur_node_a,
            gm_a,
            gm_b,
            node_name_prefix,
        )
    # return the last inserted node
    return cur_node_c


def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it equals to None, we assume that the op
    has a single non-param input.  If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.
    Non-Node values accepted by ``_is_copyable_literal`` are reused as-is.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/

    Raises:
        AssertionError: if an argument cannot be copied (a Node nested in a
            list, or a value of an unsupported type).
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args, norm_kwargs = _get_normalized_args_and_kwargs(node_a, gm_a)

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            return _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
        if _is_copyable_literal(arg):
            return arg
        if isinstance(arg, (list, tuple)):
            raise AssertionError("handling of Node inside list is not implemented")
        raise AssertionError(
            f"handling for kwarg of type {type(arg)} is not implemented"
        )

    def _new_value(idx, value):
        # stitch the inputs from base graph; copy everything else from graph A
        if idx == 0:
            return input_node_c
        if idx == 1 and input_node_c_2 is not None:
            return input_node_c_2
        return _copy_arg(value)

    # args and kwargs share one running index (kwargs continue after args)
    num_args = len(norm_args)
    new_args = tuple(_new_value(idx, arg) for idx, arg in enumerate(norm_args))
    new_kwargs = {
        kwarg_name: _new_value(num_args + idx, kwarg_val)
        for idx, (kwarg_name, kwarg_val) in enumerate(norm_kwargs.items())
    }

    node_a_shadows_c_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == "call_module":
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        return graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
        )

    assert node_a.op in ("call_function", "call_method")
    return graph_c.create_node(
        node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
    )
706
707
708def create_a_shadows_b(
709    name_a: str,
710    gm_a: GraphModule,
711    name_b: str,
712    gm_b: GraphModule,
713    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
714    logger_cls: Callable,
715    should_log_inputs: bool,
716    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
717) -> GraphModule:
718    """
719    Creates a new GraphModule consisting of the graph of C, with the meaningful
720    nodes of A shadowing the corresponding nodes of B.  For example,
721
722    Graph A:
723    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2
724
725    Graph B:
726    b0 -> op0_int8 -> b1 -> op1_int8 -> b2
727
728    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}
729
730    Graph C (A shadows B):
731
732        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
733       /                                     /
734    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1
735
736    In a nutshell, this function does the following for each node pair:
737    * copies the necessary attributes and modules from gm_a to gm_b,
738      keeping names unique
739    * adds a dtype cast op (dequant, quant, etc)
740    * adds a copy of node_a in gm_b's graph
741    * adds loggers to the outputs of node_a and node_b
742    """
743
744    if node_type_to_io_type_map is None:
745        node_type_to_io_type_map = get_node_type_to_io_type_map()
746
747    # graph_c is the graph created from copying the nodes of graph_b and inserting
748    # the shadows with the nodes copied from graph_a
749    graph_c = Graph()
750    env_c: Dict[str, Any] = {}
751    modules = dict(gm_b.named_modules())
752
753    def load_arg(a):
754        return map_arg(a, lambda node: env_c[node.name])
755
756    start_node_b_to_matched_subgraph_a_and_name = {}
757    end_node_b_to_matched_subgraph_a_and_name = {}
758    for match_name, match in matched_subgraph_pairs.items():
759        subgraph_a, subgraph_b = match
760        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
761        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
762        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = (
763            subgraph_a,
764            match_name,
765            ref_node_type_a,
766            ref_node_type_b,
767        )
768        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = (
769            subgraph_a,
770            match_name,
771            ref_node_type_a,
772            ref_node_type_b,
773        )
774
775    for node_b in gm_b.graph.nodes:
776        if node_b.op == "output":
777            graph_c.output(map_arg(node_b.args[0], load_arg))
778            continue
779
780        # calculate the flags to determine what to do with this node
781        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
782        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name
783
784        if node_b_is_start_node or node_b_is_end_node:
785            if node_b_is_start_node:
786                (
787                    subgraph_a,
788                    ref_name,
789                    ref_node_type_a,
790                    ref_node_type_b,
791                ) = start_node_b_to_matched_subgraph_a_and_name[node_b]
792            else:
793                assert node_b_is_end_node
794                (
795                    subgraph_a,
796                    ref_name,
797                    ref_node_type_a,
798                    ref_node_type_b,
799                ) = end_node_b_to_matched_subgraph_a_and_name[node_b]
800
801            all_op_types_support_shadowing = op_type_supports_shadowing(
802                subgraph_a.start_node
803            ) and op_type_supports_shadowing(node_b)
804            if not all_op_types_support_shadowing:
805                print(
806                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
807                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
808                    + ", unsupported"
809                )
810                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
811                continue
812
813            # For both start_node and end_node verify that we know how to do
814            # the dtype cast. If we do not, skip.
815            (
816                node_input_type_a,
817                node_output_type_a,
818            ) = get_node_first_input_and_output_type(
819                subgraph_a.start_node, gm_a, logger_cls, node_type_to_io_type_map
820            )
821            (
822                node_input_type_b,
823                node_output_type_b,
824            ) = get_node_first_input_and_output_type(
825                node_b, gm_b, logger_cls, node_type_to_io_type_map
826            )
827            node_io_types_known_a_and_b = (
828                node_input_type_a != NodeInputOrOutputType.UNKNOWN
829                and node_output_type_a != NodeInputOrOutputType.UNKNOWN
830                and node_input_type_b != NodeInputOrOutputType.UNKNOWN
831                and node_output_type_b != NodeInputOrOutputType.UNKNOWN
832            )
833            if not node_io_types_known_a_and_b:
834                print(
835                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
836                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
837                    + ", unknown dtype cast"
838                )
839                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
840                continue
841
842            # If we are shadowing from fp32 to int8, we need to insert
843            # quantize_per_tensor call with qparams from the previous node.
844            # Only do this if we are able to infer these qparams from the graph.
845            if (
846                node_input_type_a == NodeInputOrOutputType.INT8
847                and node_input_type_b == NodeInputOrOutputType.FP32
848            ):
849                node_a_input_qparams = get_node_input_qparams(
850                    subgraph_a.start_node, gm_a, node_type_to_io_type_map
851                )
852                if not node_a_input_qparams:
853                    print(
854                        f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
855                        + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
856                        + ", unknown input qparams"
857                    )
858                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
859                    continue
860
861            num_non_param_args_node_a = get_number_of_non_param_args(
862                subgraph_a.start_node, gm_a
863            )
864            if not _can_insert_copy_of_subgraph_a(
865                subgraph_a, gm_a, num_non_param_args_node_a
866            ):
867                print(
868                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
869                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
870                    + ", unhandled logic in subgraph copy"
871                )
872                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
873                continue
874
875            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
876            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)  # type: ignore[possibly-undefined]
877
878            if node_b_is_start_node:
879                # if necessary, log the input of node_c
880                if should_log_inputs:
881                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
882                    if isinstance(prev_node_b, Node):
883                        prev_node_c = env_c[prev_node_b.name]
884                        env_c[prev_node_c.name] = _insert_logger_after_node(
885                            prev_node_c,
886                            gm_b,
887                            logger_cls,
888                            "_ns_logger_b_inp_",
889                            node_b.name,
890                            name_b,
891                            ref_name,
892                            ref_node_type_b,
893                            NSSingleResultValuesType.NODE_INPUT.value,
894                            index_within_arg=0,
895                            index_of_arg=0,
896                            fqn=fqn_base_b,
897                        )
898                    elif isinstance(prev_node_b, list):
899                        # first, save the prev_node instances, because they
900                        # will be overwritten in the env after the first logger
901                        # is added
902                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
903
904                        for arg_idx, arg in enumerate(prev_node_b):
905                            prev_node_c = prev_node_c_list[arg_idx]
906                            env_c[prev_node_c.name] = _insert_logger_after_node(
907                                prev_node_c,
908                                gm_b,
909                                logger_cls,
910                                "_ns_logger_b_inp_",
911                                node_b.name,
912                                name_b,
913                                ref_name,
914                                ref_node_type_b,
915                                NSSingleResultValuesType.NODE_INPUT.value,
916                                index_within_arg=arg_idx,
917                                index_of_arg=0,
918                                fqn=fqn_base_b,
919                            )
920                    else:
921                        # logging of inputs which are not lists is not supported yet
922                        raise AssertionError(
923                            f"type {type(prev_node_b)} is not handled yet"
924                        )
925                # subgraph so far:
926                #
927                # (prev_node_c)+ -> (logger_c_input)?
928
929            # Note: this if statement is always True, spelling it out to clarify code
930            # intent.
931            if node_b_is_start_node or node_b_is_end_node:
932                # ensure env_c is populated with base node
933                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
934                node_c = env_c[node_b.name]
935
936                # after this point,
937                #
938                # node_a is the original node from graph_a, with parent module gm_a
939                # node_b is the original node from graph_b, with parent module gm_b
940                # node_c is the copy of node_b in graph_c
941                #
942                # subgraph so far:
943                #
944                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
945
946            if node_b_is_start_node:
947                # cast dtype from the dtype of node_c's input to the dtype of
948                # node_a's input (dequant, etc)
949                # prev_node_c = node_c.args[0]
950                prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
951                if should_log_inputs:
952                    # skip the input logger when inserting a dtype cast
953                    if isinstance(prev_node_c, Node):
954                        prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
955                    elif isinstance(prev_node_c, list):
956                        prev_node_c = [
957                            get_normalized_nth_input(arg, gm_b, 0)
958                            for arg in prev_node_c
959                        ]
960                dtype_cast_node = _insert_dtype_cast_after_node(
961                    subgraph_a.start_node,
962                    node_c,
963                    prev_node_c,
964                    gm_a,
965                    gm_b,
966                    graph_c,
967                    node_b.name + "_dtype_cast_",
968                    logger_cls,
969                    node_type_to_io_type_map,
970                )
971                # note: not inserting to env_c because all nodes which use the dtype
972                #   casts are copied from graph_a
973                #
974                # subgraph so far:
975                #
976                #           (dtype_cast_node)+
977                #                  /
978                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
979
980                # if input logging is enabled, log the input to the subgraph
981                if should_log_inputs:
982                    # TODO: explain this
983                    ref_node_name = ""
984                    if isinstance(dtype_cast_node, Node):
985                        dtype_cast_node = _insert_logger_after_node(
986                            dtype_cast_node,
987                            gm_b,
988                            logger_cls,
989                            "_ns_logger_a_inp_",
990                            ref_node_name,
991                            name_a,
992                            ref_name,
993                            ref_node_type_a,
994                            NSSingleResultValuesType.NODE_INPUT.value,
995                            index_within_arg=0,
996                            index_of_arg=0,
997                            fqn=fqn_base_a,
998                        )
999                        input_logger: Union[Node, List[Node]] = dtype_cast_node
1000                    else:
1001                        assert isinstance(dtype_cast_node, list)
1002                        new_loggers = []
1003                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(
1004                            dtype_cast_node
1005                        ):
1006                            dtype_cast_logger = _insert_logger_after_node(
1007                                dtype_cast_node_inner,
1008                                gm_b,
1009                                logger_cls,
1010                                "_ns_logger_a_inp_",
1011                                ref_node_name,
1012                                name_a,
1013                                ref_name,
1014                                ref_node_type_a,
1015                                NSSingleResultValuesType.NODE_INPUT.value,
1016                                index_within_arg=dtype_cast_idx,
1017                                index_of_arg=0,
1018                                fqn=fqn_base_a,
1019                            )
1020                            new_loggers.append(dtype_cast_logger)
1021                        dtype_cast_node = new_loggers
1022                        input_logger = dtype_cast_node
1023                    # subgraph so far:
1024                    #
1025                    #       (dtype_cast_node)+ -> (logger_a_input)?
1026                    #                  /
1027                    # prev_node_c -> (logger_c_input)? -> node_start_c
1028
1029                # hook up the new mod_a copy to be in the graph, receiving the
1030                # same inputs as mod_b does, with dtype cast to match a
1031                # Some ops, such as LSTMs, have two non-param inputs. If we have
1032                # such an op, pass the second param as well. Note: dtype casting
1033                # for the second param is not implemented yet, it can be added
1034                # later if there is a use case.
1035                node_c_second_non_param_arg = None
1036                num_non_param_args_node_a = get_number_of_non_param_args(
1037                    subgraph_a.start_node, gm_a
1038                )
1039                if num_non_param_args_node_a == 2:
1040                    # node_c_second_non_param_arg = node_c.args[1]
1041                    node_c_second_non_param_arg = get_normalized_nth_input(
1042                        node_c, gm_b, 1
1043                    )
1044                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
1045                    dtype_cast_node,
1046                    node_c_second_non_param_arg,
1047                    subgraph_a,
1048                    gm_a,
1049                    gm_b,
1050                    node_c.name + "_shadow_copy_",
1051                )
1052                env_c[node_a_shadows_c.name] = node_a_shadows_c
1053                # subgraph so far:
1054                #
1055                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
1056                #                  /
1057                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
1058
1059                if should_log_inputs:
1060                    # When we created the input logger, we left the ref_node_name
1061                    # as an empty string, because the subgraph copy did not exist
1062                    # yet. Now that the subgraph copy exists, we modify this name
1063                    # to its true value.
1064                    # Note: the alternative to this is to create the input logger
1065                    # after creating the subgraph, which is slightly more
1066                    # complicated. This is the lesser of two evils.
1067                    # input_logger = env_c[dtype_cast_node.name]
1068                    # Find the first node in the subgraph
1069                    cur_node = node_a_shadows_c
1070                    while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
1071                        cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
1072                    if isinstance(input_logger, Node):
1073                        input_logger_mod = getattr(gm_b, input_logger.name)
1074                        input_logger_mod.ref_node_name = cur_node.name
1075                    else:
1076                        assert isinstance(input_logger, list)
1077                        for input_logger_inner in input_logger:
1078                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
1079                            input_logger_mod.ref_node_name = cur_node.name
1080
1081                # hook up a logger to the mod_a copy
1082                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
1083                    env_c[node_a_shadows_c.name],
1084                    gm_b,
1085                    logger_cls,
1086                    "_ns_logger_a_",
1087                    node_a_shadows_c.name,
1088                    name_a,
1089                    ref_name,
1090                    ref_node_type_a,
1091                    NSSingleResultValuesType.NODE_OUTPUT.value,
1092                    index_within_arg=0,
1093                    index_of_arg=0,
1094                    fqn=fqn_base_a,
1095                )
1096                # subgraph so far:
1097                #
1098                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
1099                #                  /
1100                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
1101
1102            if node_b_is_end_node:
1103                # hook up a logger to the mod_b copy
1104                env_c[node_b.name] = _insert_logger_after_node(
1105                    env_c[node_b.name],
1106                    gm_b,
1107                    logger_cls,
1108                    "_ns_logger_b_",
1109                    node_b.name,
1110                    name_b,
1111                    ref_name,
1112                    ref_node_type_b,
1113                    NSSingleResultValuesType.NODE_OUTPUT.value,
1114                    index_within_arg=0,
1115                    index_of_arg=0,
1116                    fqn=fqn_base_b,
1117                )
1118                # subgraph so far:
1119                #
1120                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
1121                #                  /
1122                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
1123                #
1124                # Note: node_start_c may be the same node as node_end_c, or they
1125                # may have nodes inbetween.
1126
1127        else:
1128            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
1129
1130    gm_c = GraphModule(gm_b, graph_c)
1131    return gm_c
1132