xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import operator
4import warnings
5from collections import namedtuple
6from dataclasses import dataclass
7from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
8
9import torch
10import torch.nn as nn
11from torch.ao.quantization import QConfigAny, QuantType
12from torch.ao.quantization.backend_config import DTypeWithConstraints
13from torch.ao.quantization.fake_quantize import (
14    FakeQuantizeBase,
15    FixedQParamsFakeQuantize,
16)
17from torch.ao.quantization.observer import (
18    _is_activation_post_process,
19    FixedQParamsObserver,
20    ObserverBase,
21)
22from torch.ao.quantization.qconfig import (
23    float16_dynamic_qconfig,
24    float16_static_qconfig,
25    qconfig_equals,
26)
27from torch.ao.quantization.qconfig_mapping import QConfigMapping
28from torch.ao.quantization.stubs import DeQuantStub
29from torch.ao.quantization.utils import (
30    _assert_and_get_unique_device,
31    activation_is_statically_quantized,
32)
33from torch.fx import GraphModule, map_arg
34from torch.fx.graph import Graph, Node
35
36# importing the lib so that the quantized_decomposed ops are registered
37from ._decomposed import quantized_decomposed_lib  # noqa: F401
38from .custom_config import PrepareCustomConfig
39
40
41# TODO: revisit this list. Many helper methods shouldn't be public
42__all__ = [
43    "all_node_args_except_first",
44    "all_node_args_have_no_tensors",
45    "assert_and_get_unique_device",
46    "collect_producer_nodes",
47    "create_getattr_from_value",
48    "create_node_from_old_node_preserve_meta",
49    "EMPTY_ARG_DICT",
50    "get_custom_module_class_keys",
51    "get_linear_prepack_op_for_dtype",
52    "get_new_attr_name_with_prefix",
53    "get_non_observable_arg_indexes_and_types",
54    "get_qconv_prepack_op",
55    "get_skipped_module_name_and_classes",
56    "graph_module_from_producer_nodes",
57    "maybe_get_next_module",
58    "NodeInfo",
59    "node_arg_is_bias",
60    "node_arg_is_weight",
61    "NON_OBSERVABLE_ARG_DICT",
62    "NON_QUANTIZABLE_WEIGHT_OPS",
63    "return_arg_list",
64    "ObservedGraphModuleAttrs",
65]
66
67NON_QUANTIZABLE_WEIGHT_OPS = {
68    torch.nn.functional.layer_norm,
69    torch.nn.functional.group_norm,
70    torch.nn.functional.instance_norm,
71}
72
73
74@dataclass
75class ObservedGraphModuleAttrs:
76    node_name_to_qconfig: Dict[str, QConfigAny]
77    node_name_to_scope: Dict[str, Tuple[str, type]]
78    prepare_custom_config: PrepareCustomConfig
79    equalization_node_name_to_qconfig: Dict[str, Any]
80    qconfig_mapping: QConfigMapping
81    is_qat: bool
82    observed_node_names: Set[str]
83    is_observed_standalone_module: bool = False
84    standalone_module_input_quantized_idxs: Optional[List[int]] = None
85    standalone_module_output_quantized_idxs: Optional[List[int]] = None
86
87
88def node_arg_is_weight(node: Node, arg: Any) -> bool:
89    """Returns if node arg is weight"""
90    weight_index = None
91    if "target_dtype_info" in node.meta:
92        weight_index = node.meta["target_dtype_info"].get("weight_index", None)
93    if (
94        weight_index is not None
95        and weight_index < len(node.args)
96        and node.args[weight_index] is arg
97    ):
98        return True
99    return node.kwargs.get("weight") is arg
100
101
102def node_arg_is_bias(node: Node, arg: Any) -> bool:
103    """Returns if node arg is bias"""
104    bias_index = None
105    if "target_dtype_info" in node.meta:
106        bias_index = node.meta["target_dtype_info"].get("bias_index", None)
107    if (
108        bias_index is not None
109        and bias_index < len(node.args)
110        and node.args[bias_index] is arg
111    ):
112        return True
113    return node.kwargs.get("bias") is arg
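
# For example (illustrative): for an fx node representing F.linear(x, w, b),
# assuming node.meta["target_dtype_info"] carries weight_index=1 and bias_index=2
# for this pattern (otherwise only the keyword-argument fallback applies),
# >> node_arg_is_weight(linear_node, linear_node.args[1])  # True
# >> node_arg_is_bias(linear_node, linear_node.args[2])    # True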
114
115
116def get_custom_module_class_keys(
117    custom_module_mapping: Dict[QuantType, Dict[Type, Type]]
118) -> List[Any]:
119    r"""Get all the unique custom module keys in the custom config dict
120    e.g.
121    Input:
122    {
123        QuantType.STATIC: {
124            CustomModule1: ObservedCustomModule
125        },
126        QuantType.DYNAMIC: {
127            CustomModule2: DynamicObservedCustomModule
128        },
129        QuantType.WEIGHT_ONLY: {
130            CustomModule3: WeightOnlyObservedCustomModule
131        },
132    }
133
134    Output:
135    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
136    [CustomModule1, CustomModule2, CustomModule3]
137    """
138    # using set to dedup
139    float_custom_module_classes: Set[Any] = set()
140    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
141        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
142        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
143        float_custom_module_classes |= quant_mode_custom_module_classes
144    return list(float_custom_module_classes)
145
146
147def get_linear_prepack_op_for_dtype(dtype):
148    if dtype == torch.float16:
149        return torch.ops.quantized.linear_prepack_fp16
150    elif dtype == torch.qint8:
151        return torch.ops.quantized.linear_prepack
152    else:
153        raise Exception("can't get linear prepack op for dtype:", dtype)  # noqa: TRY002
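
# For example:
# >> get_linear_prepack_op_for_dtype(torch.qint8)    # torch.ops.quantized.linear_prepack
# >> get_linear_prepack_op_for_dtype(torch.float16)  # torch.ops.quantized.linear_prepack_fp16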
154
155
156def get_qconv_prepack_op(conv_op: Callable) -> Callable:
157    prepack_ops = {
158        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
159        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
160        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
161        torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
162        torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
163        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
164    }
165    prepack_op = prepack_ops.get(conv_op, None)
166    assert prepack_op, f"Didn't find prepack op for {conv_op}"
167    return prepack_op
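
# For example:
# >> get_qconv_prepack_op(torch.nn.functional.conv2d)            # torch.ops.quantized.conv2d_prepack
# >> get_qconv_prepack_op(torch.nn.functional.conv_transpose2d)  # torch.ops.quantized.conv_transpose2d_prepack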
168
169
170# Returns a function that generates a new, unused attribute name for a module
171# with the given prefix, for example,
172# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
173# >> new_name = get_new_observer_name(module)
174# new_name will be an unused attribute name on module, e.g. `_observer_1`
175def get_new_attr_name_with_prefix(prefix: str) -> Callable:
176    prefix = prefix.replace(".", "_")
177
178    def get_new_attr_name(module: torch.nn.Module):
179        def get_attr_name(i: int):
180            return prefix + str(i)
181
182        i = 0
183        attr_name = get_attr_name(i)
184        while hasattr(module, attr_name):
185            i += 1
186            attr_name = get_attr_name(i)
187        return attr_name
188
189    return get_new_attr_name
190
191
192def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
193    r"""Starting from a target node, trace back until we hit inpu or
194    getattr node. This is used to extract the chain of operators
195    starting from getattr to the target node, for example
196    def forward(self, x):
197      observed = self.observer(self.weight)
198      return F.linear(x, observed)
199    collect_producer_nodes(observed) will either return a list of nodes that
200    produces the observed node or None if we can't extract a self contained
201    graph without free variables(inputs of the forward function).
202    """
203    nodes = [node]
204    frontier = [node]
205    while frontier:
206        node = frontier.pop()
207        all_args = list(node.args) + list(node.kwargs.values())
208        for arg in all_args:
209            if not isinstance(arg, Node):
210                continue
211            if arg.op == "placeholder":
212                # hit input, can't fold in this case
213                return None
214            nodes.append(arg)
215            if not (arg.op == "call_function" and arg.target == getattr):
216                frontier.append(arg)
217    return nodes
218
219
220def graph_module_from_producer_nodes(
221    root: GraphModule, producer_nodes: List[Node]
222) -> GraphModule:
223    r"""Construct a graph module from extracted producer nodes
224    from `collect_producer_nodes` function
225    Args:
226      root: the root module for the original graph
227      producer_nodes: a list of nodes we use to construct the graph
228    Return:
229      A graph module constructed from the producer nodes
230    """
231    assert len(producer_nodes) > 0, "list of producer nodes can not be empty"
232    # since we traced back from node to getattr
233    producer_nodes.reverse()
234    graph = Graph()
235    env: Dict[Any, Any] = {}
236
237    def load_arg(a):
238        return map_arg(a, lambda node: env[node])
239
240    for producer_node in producer_nodes:
241        env[producer_node] = graph.node_copy(producer_node, load_arg)
242    graph.output(load_arg(producer_nodes[-1]))
243    graph_module = GraphModule(root, graph)
244    return graph_module
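
# Example (illustrative sketch) of folding an observed weight at convert time,
# assuming `model` is a prepared GraphModule and `observed_weight_node` is the
# Node feeding the weight input of a functional linear/conv:
# >> producer_nodes = collect_producer_nodes(observed_weight_node)
# >> if producer_nodes is not None:
# >>     folded_gm = graph_module_from_producer_nodes(model, producer_nodes)
# >>     folded_weight = folded_gm()  # run the extracted subgraph to materialize the value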
245
246
247# TODO: delete
248def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
249    """
250    Returns the unique device for a module, or None if no device is found.
251    Throws an error if multiple devices are detected.
252    """
253    return _assert_and_get_unique_device(module)
254
255
256def create_getattr_from_value(
257    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
258) -> Node:
259    """
260    Given a value of any type, creates a getattr node corresponding to the value and
261    registers the value as a buffer to the module.
262    """
263    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
264    attr_name = get_new_attr_name(module)
265    device = assert_and_get_unique_device(module)
266    new_value = (
267        value.clone().detach()
268        if isinstance(value, torch.Tensor)
269        else torch.tensor(value, device=device)
270    )
271    module.register_buffer(attr_name, new_value)
272    # Create get_attr with value
273    attr_node = graph.create_node("get_attr", attr_name)
274    return attr_node
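
# Example (illustrative sketch), e.g. when materializing quantization params in
# the graph, assuming `model` is a GraphModule and `graph` is its Graph:
# >> scale_node = create_getattr_from_value(model, graph, "_scale_", 0.25)
# >> # scale_node is a get_attr Node whose target is the newly registered buffer,
# >> # e.g. "_scale_0"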
275
276
277def all_node_args_have_no_tensors(
278    node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]
279) -> bool:
280    """
281    If we know for sure that all of this node's args have no
282    tensors (are primitives), return True.  If we either
283    find a tensor or are not sure, return False. Note: this
284    function is not exact.
285    """
286    if cache and node in cache:
287        return cache[node]
288
289    result = False  # will be overwritten
290    if not isinstance(node, Node):
291        result = True
292    elif node.op == "placeholder":
293        result = False
294    elif node.op == "call_module":
295        assert isinstance(node.target, str)
296        if _is_activation_post_process(modules[node.target]):
297            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
298        else:
299            result = False
300    elif node.op == "call_function" and node.target is operator.getitem:
301        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
302    elif node.op == "get_attr":
303        result = False
304    elif node.target is getattr and node.args[1] in ["ndim", "shape"]:
305        # x1 = x0.ndim
306        result = True
307    elif node.op == "call_method" and node.target == "size":
308        # x1 = x0.size(0)
309        result = True
310    else:
311        found_one_tensor = False
312        for arg in node.args:
313            if isinstance(arg, list):
314                for list_el in arg:
315                    if isinstance(list_el, Node):
316                        this_list_el_args_have_no_tensors = (
317                            all_node_args_have_no_tensors(list_el, modules, cache)
318                        )
319                        found_one_tensor = found_one_tensor or (
320                            not this_list_el_args_have_no_tensors
321                        )
322                        # If found_one_tensor is True, there is no point in
323                        # recursing further as the end result will always
324                        # be True.
325                        # TODO(future PR): remove this entire function and
326                        # change to dtype inference without recursion.
327                        if found_one_tensor:
328                            result = not found_one_tensor
329                            if cache:
330                                cache[node] = result
331                            return result
332            elif isinstance(arg, int):
333                pass
334            else:
335                if isinstance(arg, Node):
336                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(
337                        arg, modules, cache
338                    )
339                    found_one_tensor = found_one_tensor or (
340                        not this_arg_args_have_no_tensors
341                    )
342                    # If found_one_tensor is True, there is no point in
343                    # recursing further as the end result will always
344                    # be True.
345                    # TODO(future PR): remove this entire function and
346                    # change to dtype inference without recursion.
347                    if found_one_tensor:
348                        result = not found_one_tensor
349                        if cache:
350                            cache[node] = result
351                        return result
352                else:
353                    found_one_tensor = True
354            result = not found_one_tensor
355    if cache:
356        cache[node] = result
357    return result
358
359
360def all_node_args_except_first(node: Node) -> List[int]:
361    """
362    Returns all node arg indices after the first
363    """
364    return list(range(1, len(node.args)))
365
366
367def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
368    """
369    Constructs a function that takes a node as arg and returns the arg_indices
370    that are valid for node.args
371    """
372
373    def arg_indices_func(node: Node) -> List[int]:
374        return [i for i in arg_indices if i < len(node.args)]
375
376    return arg_indices_func
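
# For example:
# >> second_arg_only = return_arg_list([1])
# >> second_arg_only(node)  # [1] if node has at least two positional args, else []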
377
378
379NodeInfo = namedtuple("NodeInfo", "op target")
380
381# This dict identifies which arg indices of a node are non-tensors, so that
382# they can be propagated correctly, since inserting observers
383# for them would cause errors
384
385NON_OBSERVABLE_ARG_DICT: Dict[
386    NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]
387] = {
388    NodeInfo("call_method", "masked_fill"): {
389        torch.bool: return_arg_list([1]),
390        float: return_arg_list([2]),
391    },
392    NodeInfo("call_method", "permute"): {int: all_node_args_except_first},
393    NodeInfo("call_method", "repeat"): {int: all_node_args_except_first},
394    NodeInfo("call_method", "reshape"): {int: all_node_args_except_first},
395    NodeInfo("call_method", "size"): {int: return_arg_list([1])},
396    NodeInfo("call_method", "transpose"): {int: all_node_args_except_first},
397    NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first},
398    NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])},
399    NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])},
400    NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])},
401    NodeInfo("call_method", "view"): {int: all_node_args_except_first},
402}
403
404EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}
405
406
407def get_non_observable_arg_indexes_and_types(
408    node: Node,
409) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
410    """
411    Returns a dict mapping each non-float-tensor arg type to a function that,
412    given the node, returns the list of arg indices of that type
413    """
414    info = NodeInfo(node.op, node.target)
415
416    return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
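
# For example (illustrative): for a call_method node `x.masked_fill(mask, value)`,
# >> arg_indices = get_non_observable_arg_indexes_and_types(node)
# >> arg_indices[torch.bool](node)  # [1], the index of the mask argument
# >> arg_indices[float](node)       # [2], the index of the fill value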
417
418
419def maybe_get_next_module(
420    node: Node,
421    modules: Dict[str, nn.Module],
422    target_module_type: Optional[Type[nn.Module]] = None,
423    target_functional_type: Any = None,
424) -> Optional[Node]:
425    """Gets the next module that matches what is needed in
426    is_target_module_type if it exists
427
428    Args:
429        node: The node whose users we want to look at
430        target_module_type: Module type that we want to check
431        target_functional_type: Functional type that we want to check
432    """
433
434    for user in node.users.keys():
435        if (
436            user.op == "call_module"
437            and target_module_type is not None
438            and isinstance(modules[str(user.target)], target_module_type)
439        ):
440            return user
441        elif (
442            user.op == "call_function"
443            and target_functional_type is not None
444            and user.target == target_functional_type
445        ):
446            return user
447
448    return None
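
# Example (illustrative sketch): check whether a conv node is directly followed
# by a BatchNorm2d module, assuming `named_modules` is dict(model.named_modules()):
# >> bn_node = maybe_get_next_module(
# >>     conv_node, named_modules, target_module_type=torch.nn.BatchNorm2d
# >> )
# >> if bn_node is not None:
# >>     ...  # conv_node has a call_module user that is a BatchNorm2d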
449
450
451def create_node_from_old_node_preserve_meta(
452    quantized_graph: Graph,
453    create_node_args: Tuple[Any, ...],
454    old_node: Node,
455) -> Node:
456    """
457    Creates `new_node` and copies the necessary metadata to it from `old_node`.
458    """
459    new_node = quantized_graph.create_node(*create_node_args)
460    new_node.stack_trace = old_node.stack_trace
461    return new_node
462
463
464def get_skipped_module_name_and_classes(
465    prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool
466) -> Tuple[List[str], List[Type[Any]]]:
467    skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
468    skipped_module_classes = copy.copy(
469        prepare_custom_config.non_traceable_module_classes
470    )
471    if not is_standalone_module:
472        # standalone module and custom module configs are applied in the top-level module
473        skipped_module_names += list(
474            prepare_custom_config.standalone_module_names.keys()
475        )
476        skipped_module_classes += list(
477            prepare_custom_config.standalone_module_classes.keys()
478        )
479        skipped_module_classes += get_custom_module_class_keys(
480            prepare_custom_config.float_to_observed_mapping
481        )
482
483    return skipped_module_names, skipped_module_classes
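
# Example (illustrative sketch; CustomNonTraceable is a stand-in user-defined class):
# >> prepare_custom_config = (
# >>     PrepareCustomConfig()
# >>     .set_non_traceable_module_names(["submodule.lstm"])
# >>     .set_non_traceable_module_classes([CustomNonTraceable])
# >> )
# >> names, classes = get_skipped_module_name_and_classes(prepare_custom_config, False)
# >> # names == ["submodule.lstm"]; classes contains CustomNonTraceable plus any
# >> # standalone and custom module classes registered in the config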
484
485
486def _is_custom_module_lstm(
487    node: Node,
488    named_modules: Dict[str, torch.nn.Module],
489    qconfig: QConfigAny = None,
490    # QuantizeHandler, but we cannot include the type here due to circular imports
491    qhandler: Optional[Any] = None,
492) -> bool:
493    """
494    Return whether this refers to the custom module LSTM flow.
495    """
496    mod = _get_module(node, named_modules)
497    if qconfig is not None and qhandler is not None:
498        assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
499        return (
500            isinstance(mod, torch.nn.LSTM)
501            and activation_is_statically_quantized(qconfig)
502            and qhandler.is_custom_module()
503        )
504    else:
505        return isinstance(mod, torch.ao.nn.quantizable.LSTM)
506
507
508def _is_custom_module_mha(
509    node: Node,
510    named_modules: Dict[str, torch.nn.Module],
511    qconfig: QConfigAny = None,
512    # QuantizeHandler, but we cannot include the type here due to circular imports
513    qhandler: Optional[Any] = None,
514) -> bool:
515    """
516    Return whether this refers to the custom module MultiheadAttention flow.
517    """
518    mod = _get_module(node, named_modules)
519    if qconfig is not None and qhandler is not None:
520        assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
521        return (
522            isinstance(mod, torch.nn.MultiheadAttention)
523            and activation_is_statically_quantized(qconfig)
524            and qhandler.is_custom_module()
525        )
526    else:
527        return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
528
529
530def _get_module(
531    node: Node, named_modules: Dict[str, torch.nn.Module]
532) -> Optional[torch.nn.Module]:
533    """
534    If `node` refers to a call_module node, return the module, else None.
535    """
536    if node.op == "call_module" and str(node.target) in named_modules:
537        return named_modules[str(node.target)]
538    else:
539        return None
540
541
542def _insert_dequant_stub(
543    node: Node,
544    model: torch.nn.Module,
545    named_modules: Dict[str, torch.nn.Module],
546    graph: Graph,
547) -> Node:
548    """
549    Attach a `DeQuantStub` to the model and create a node that calls this
550    `DeQuantStub` on the output of `node`, similar to how observers are inserted.
551    """
552    prefix = "dequant_stub_"
553    get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
554    dequant_stub_name = get_new_dequant_stub_name(model)
555    dequant_stub = DeQuantStub()
556    setattr(model, dequant_stub_name, dequant_stub)
557    named_modules[dequant_stub_name] = dequant_stub
558    with graph.inserting_after(node):
559        return graph.call_module(dequant_stub_name, (node,))
560
561
562def _insert_dequant_stubs_for_custom_module_lstm_output(
563    node: Node,
564    model: torch.nn.Module,
565    named_modules: Dict[str, torch.nn.Module],
566    graph: Graph,
567) -> Node:
568    """
569    Insert DeQuantStubs after each internal output node of custom module LSTM.
570
571    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)).
572    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
573    components through `getitem`. This function transforms the graph as follows:
574
575      (1) Split the LSTM node into (output, (hidden0, hidden1))
576      (2) Insert a DeQuantStub after each internal node
577      (3) Recombine the DeQuantStubs into the same structure as before
578      (4) Reroute all consumers of the original LSTM node and its sub-nodes
579          (e.g. lstm[0])
580
581    Before:
582                   lstm_output
583                        |
584                        v
585                  original_user(s)
586    After:
587                   lstm_output
588                  /           \\
589                 /  (getitem)  \\
590                /               \\
591               v                 v
592             output            hidden
593               |               /   \\
594         (DeQuantStub)        (getitem)
595               |             /       \\
596               v            v         v
597           output_dq     hidden0    hidden1
598               |            |         |
599               |    (DeQuantStub) (DeQuantStub)
600               |            |         |
601               |            v         v
602               |      hidden0_dq  hidden1_dq
603               |            \\       /
604               |              (tuple)
605               |              \\   /
606               |               v  v
607               |             hidden_dq
608               \\               /
609                \\   (tuple)   /
610                 v            v
611                 lstm_output_dq
612                       |
613                       v
614                original_user(s)
615
616    For step (4), reroute all users of the original LSTM node(s) as follows:
617      lstm_output -> lstm_output_dq
618      lstm_output[0] -> output_dq
619      lstm_output[1] -> hidden_dq
620      lstm_output[1][0] -> hidden0_dq
621      lstm_output[1][1] -> hidden1_dq
622
623    Return the node `lstm_output_dq`.
624    """
625    # (1) Split the LSTM node into (output, (hidden0, hidden1))
626    # (2) Insert a DeQuantStub after each internal node
627    with graph.inserting_after(node):
628        output = graph.call_function(operator.getitem, (node, 0))
629        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
630    with graph.inserting_after(output_dq):
631        hidden = graph.call_function(operator.getitem, (node, 1))
632    with graph.inserting_after(hidden):
633        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
634        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
635    with graph.inserting_after(hidden0_dq):
636        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
637        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)
638
639    # (3) Recombine the DeQuantStubs into the same structure as before
640    with graph.inserting_after(hidden1_dq):
641        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
642    with graph.inserting_after(hidden_dq):
643        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))
644
645    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
646    for user in list(node.users.keys()):
647        if user != output and user != hidden:
648            user.replace_input_with(node, lstm_output_dq)
649    # The getitem and tuple nodes we added here may interfere with reference quantized
650    # pattern matching, so we need to redirect the consumers of internal nodes to the
651    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
652    # in order to preserve reference patterns like "dequantize - consumer - quantize".
653    _reroute_tuple_getitem_pattern(graph)
654    return lstm_output_dq
655
656
657def _maybe_get_custom_module_lstm_from_node_arg(
658    arg: Node,
659    named_modules: Dict[str, torch.nn.Module],
660) -> Optional[Node]:
661    """
662    Given an argument of a node, if the argument leads back to a custom module LSTM through
663    one of the patterns below, return the custom module LSTM node; otherwise return None.
664
665    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
666    skip inserting input observers for this node. This is because custom module LSTM produces
667    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
668    would unnecessarily quantize the outputs again.
669
670      lstm -> consumer
671
672    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
673    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
674    This tuple can be consumed in one of four ways:
675
676      lstm -> getitem -> DeQuantStub -> consumer                       # consume lstm[0]
677      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer   # consume lstm[1]
678      lstm -> getitem -> getitem -> DeQuantStub -> consumer            # consume lstm[1][0] or lstm[1][1]
679      lstm -> getitem -> DeQuantStub -> tuple -> consumer              # consume lstm
680
681    Thus, we must match against the above patterns instead of simply checking the parent node
682    to determine whether this node is a consumer of a custom module LSTM.
683    """
684
685    def match_dq(a):
686        return isinstance(_get_module(a, named_modules), DeQuantStub)
687
688    def match_lstm(a):
689        return _is_custom_module_lstm(a, named_modules)
690
691    def match_getitem(a):
692        return a.op == "call_function" and a.target == operator.getitem
693
694    def match_tuple(a):
695        return a.op == "call_function" and a.target == tuple
696
697    def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
698        """
699        Traverse up the graph and match the args one by one.
700        If there is a match, return the last matched node, or None otherwise.
701        """
702        a = arg
703        for i, match in enumerate(match_pattern):
704            if not match(a):
705                return None
706            # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
707            if i < len(match_pattern) - 1:
708                if match == match_tuple:
709                    a = a.args[0][0]  # type: ignore[assignment,index]
710                else:
711                    a = a.args[0]  # type: ignore[assignment]
712        return a
713
714    all_match_patterns = [
715        [match_dq, match_getitem, match_lstm],
716        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
717        [match_dq, match_getitem, match_getitem, match_lstm],
718        [match_tuple, match_dq, match_getitem, match_lstm],
719    ]
720
721    for p in all_match_patterns:
722        matched_node = _match_pattern(p)
723        if matched_node is not None:
724            return matched_node
725    return None
726
727
728def _reroute_tuple_getitem_pattern(graph: Graph):
729    """
730    Search for patterns where N consecutive `tuple` call_function nodes are followed by
731    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
732    If we find this pattern, reroute the consumers of the last `getitem` to skip these
733    N `tuple` and `getitem` nodes.
734
735    Before:
736
737        a   b     c
738        |   \\   /
739        \\   tuple
740         \\   /
741          tuple
742            |
743        getitem(1)
744            |
745        getitem(0)
746            |
747            d
748
749    After:
750
751        b
752        |
753        d
754    """
755
756    def find_patterns(
757        node: Node,
758        index_stack: List[int],
759        current_pattern: List[Node],
760        matched_patterns: List[List[Node]],
761        seen: Set[Tuple[Node, Tuple[int, ...]]],
762    ):
763        """
764        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
765        starting at the given node.
766
767        We use a stack to keep track of the expected `getitem` indices, since these are
768        reversed from the `tuple` indices. In the above example, the stack after
769        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
770        and then by getitem(0).
771
772        TODO: traverse upwards from the output and handle the case when tuple is not a
773        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
774        """
775        if len(index_stack) == 0 and len(current_pattern) > 0:
776            matched_patterns.append(copy.copy(current_pattern))
777            current_pattern.clear()
778
779        # Avoid duplicating work
780        state = (node, tuple(index_stack))
781        if state in seen:
782            return
783        seen.add(state)
784
785        # Iterate through users of this node to find tuple/getitem nodes to match
786        for user in node.users:
787            if user.op == "call_function" and user.target == tuple:
788                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
789                    if user_arg == node:
790                        index_stack.append(i)
791                        current_pattern.append(user)
792                        find_patterns(
793                            user, index_stack, current_pattern, matched_patterns, seen
794                        )
795            elif user.op == "call_function" and user.target == operator.getitem:
796                if len(index_stack) > 0:
797                    if user.args[1] == index_stack[-1]:
798                        index_stack.pop()
799                        current_pattern.append(user)
800                        find_patterns(
801                            user, index_stack, current_pattern, matched_patterns, seen
802                        )
803        return matched_patterns
804
805    # Collect all matched patterns
806    matched_patterns: List[List[Node]] = []
807    seen: Set[Tuple[Node, Tuple[int, ...]]] = set()  # (node, index_stack)
808    for node in graph.nodes:
809        find_patterns(node, [], [], matched_patterns, seen)
810
811    # For each pattern, redirect all consumers of the last getitem node to the correct input
812    # of the first tuple node
813    for pattern in matched_patterns:
814        first_tuple = pattern[0]
815        last_getitem = pattern[-1]
816        assert first_tuple.op == "call_function" and first_tuple.target == tuple
817        assert (
818            last_getitem.op == "call_function"
819            and last_getitem.target == operator.getitem
820        )
821        last_getitem_index = last_getitem.args[1]
822        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
823        for user in list(last_getitem.users.keys()):
824            user.replace_input_with(last_getitem, new_input)  # type: ignore[arg-type]
825
826
827def _get_observer_from_activation_post_process(
828    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
829) -> ObserverBase:
830    """
831    If `activation_post_process` is an observer, return the observer.
832    If `activation_post_process` is a fake quantize, return the internal observer.
833    """
834    if isinstance(activation_post_process, ObserverBase):
835        return activation_post_process
836    else:
837        assert isinstance(activation_post_process, FakeQuantizeBase)
838        return activation_post_process.activation_post_process  # type: ignore[return-value]
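
# For example:
# >> obs = torch.ao.quantization.MinMaxObserver()
# >> _get_observer_from_activation_post_process(obs) is obs  # True
# >> fq = torch.ao.quantization.FakeQuantize()  # wraps a MovingAverageMinMaxObserver by default
# >> _get_observer_from_activation_post_process(fq)          # the wrapped observer instance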
839
840
841def _qconfig_satisfies_dtype_config_constraints(
842    qconfig: QConfigAny,
843    dtype_with_constraints: DTypeWithConstraints,
844    is_activation: bool = True,
845) -> bool:
846    """
847    Return whether `qconfig` satisfies the following constraints from the backend,
848    specified through the activation and weight DTypeWithConstraints.
849
850        1. QConfig specifies a quantization range that falls within the backend's, if any
851        2. QConfig specifies a min scale value that is >= the backend's, if any
852        3. QConfig specifies a FixedQParamsObserver or FixedQParamsFakeQuantize whose
853           scale and zero point match the backend's, if any
854
855    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
856    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
857    """
858
859    # TODO: log warnings only when the user enabled a debug flag
860    def _activation_post_process_satisfies_dtype_config_constraints(
861        activation_post_process: Union[ObserverBase, FakeQuantizeBase],
862        dtype_with_constraints: DTypeWithConstraints,
863        debug_string: str,
864    ) -> bool:
865        observer = _get_observer_from_activation_post_process(activation_post_process)
866        app_quant_min = getattr(observer, "quant_min", None)
867        app_quant_max = getattr(observer, "quant_max", None)
868        # TODO: for now, just use the existing eps value as scale_min. In the future, we should
869        # resolve the differences between the two, either by renaming eps or some other way
870        app_scale_min = getattr(observer, "eps", None)
871        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
872        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
873        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
874        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
875        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
876        # check quantization ranges
877        if backend_quant_min is not None and backend_quant_max is not None:
878            if app_quant_min is None or app_quant_max is None:
879                warnings.warn(
880                    f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}"
881                )
882                return False
883            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
884                warnings.warn(
885                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
886                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
887                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
888                    f"ignoring {qconfig}"
889                )
890                return False
891        # check scale min
892        if backend_scale_min is not None:
893            if app_scale_min is None:
894                warnings.warn(
895                    f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}"
896                )
897                return False
898            if app_scale_min < backend_scale_min:
899                warnings.warn(
900                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
901                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
902                )
903                return False
904        # check fixed scale and zero point
905        if (
906            backend_scale_exact_match is not None
907            and backend_zero_point_exact_match is not None
908        ):
909            # For tests only, accept the following qconfigs for now
910            # TODO: handle fp16 qconfigs properly
911            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
912                if qconfig_equals(qconfig, accepted_qconfig):
913                    return True
914            suggestion_str = (
915                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
916                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
917                '    qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n'
918                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
919            )
920            if not isinstance(
921                activation_post_process, FixedQParamsObserver
922            ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize):
923                warnings.warn(
924                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
925                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
926                )
927                return False
928            if (
929                observer.scale != backend_scale_exact_match
930                or observer.zero_point != backend_zero_point_exact_match
931            ):
932                warnings.warn(
933                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
934                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
935                    f"ignoring {qconfig}.\n{suggestion_str}"
936                )
937                return False
938        return True
939
940    if qconfig is None or dtype_with_constraints.dtype is None:
941        return True
942
943    activation_post_process_ctr = (
944        qconfig.activation if is_activation else qconfig.weight
945    )
946    debug_string = "activation" if is_activation else "weight"
947    satisfies_constraints = True
948    if activation_post_process_ctr is not None:
949        activation_post_process = activation_post_process_ctr()
950        assert _is_activation_post_process(activation_post_process)
951        # If dtypes don't match, don't check the activation_post_process and return True early
952        if activation_post_process.dtype != dtype_with_constraints.dtype:
953            return True
954        satisfies_constraints = (
955            _activation_post_process_satisfies_dtype_config_constraints(
956                activation_post_process, dtype_with_constraints, debug_string
957            )
958        )
959    return satisfies_constraints
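
# Example (illustrative sketch): checking a default qconfig against a backend
# activation constraint. The DTypeWithConstraints values below are made up for
# illustration; real constraints come from a BackendConfig:
# >> qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
# >> constraints = DTypeWithConstraints(
# >>     dtype=torch.quint8,
# >>     quant_min_lower_bound=0,
# >>     quant_max_upper_bound=255,
# >> )
# >> _qconfig_satisfies_dtype_config_constraints(qconfig, constraints, is_activation=True)
# >> # True when the activation observer's dtype matches and its
# >> # (quant_min, quant_max) range falls within (0, 255)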
960