# mypy: allow-untyped-defs
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization import (
    CUSTOM_KEY,
    NUMERIC_DEBUG_HANDLE_KEY,
    ObserverOrFakeQuantize,
    QConfigMapping,
)
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.fx.prepare import (
    _create_obs_or_fq_from_qspec,
    _insert_obs_or_fq,
    _is_activation_post_process_node,
    _save_state,
)
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.quantizer import (
    EdgeOrNode,
    QuantizationSpecBase,
    SharedQuantizationSpec,
)
from torch.fx import Graph, GraphModule, Node
from torch.fx.node import Argument


# TODO: make pt2e folder private?
__all__ = [
    "prepare",
]


def _find_root_edge_or_node(
    edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> EdgeOrNode:
    """Find the root node for the sharing tree
    Args:
        edge_or_node: edge/node that we want to find the root for
        shared_with_map: each edge/node points to its parent; the root node points to itself

    Returns:
        root edge/node
    """
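    # Illustrative sketch (hypothetical keys): with
    #   shared_with_map = {"a": "a", "b": "a", "c": "b"}
    # _find_root_edge_or_node("c", shared_with_map) returns "a", and path compression
    # rewrites shared_with_map["c"] to point directly at "a".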
    parent = shared_with_map[edge_or_node]
    if parent == edge_or_node:
        return edge_or_node
    root = _find_root_edge_or_node(parent, shared_with_map)
    # path compression
    shared_with_map[edge_or_node] = root
    return root


def _union(
    parent: EdgeOrNode,
    child: EdgeOrNode,
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode],
) -> None:
    """Merge the subtree for `child` with `parent`; the order is important here"""
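    # After this call the merged tree is rooted at `parent`'s root, so subsequent
    # root lookups for `child` (and anything sharing with it) resolve to that root.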
    root_parent = _find_root_edge_or_node(parent, shared_with_map)
    root_child = _find_root_edge_or_node(child, shared_with_map)
    # union the two trees by pointing the root of child to root of parent
    shared_with_map[root_child] = root_parent


def _update_shared_with(
    child: EdgeOrNode,
    qspec: QuantizationSpecBase,
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode],
):
    """Update the `shared_with_map` based on the qspec. This applies the `SharedQuantizationSpec`
    configuration and establishes the relationship between `edge_or_node` and the edge/node that it
    is pointing to; we'll use this information in the end to get the group id
    """
    if isinstance(qspec, SharedQuantizationSpec):
        parent = qspec.edge_or_node
        # we point from edge_or_node to the node that it is sharing_with, e.g.
        # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
        _union(parent, child, shared_with_map)


def _unwrap_shared_qspec(
    qspec: QuantizationSpecBase,
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode],
) -> QuantizationSpecBase:
    """Unwraps qspec to get the final root qspec (non-SharedQuantizationSpec).
    If qspec is a SharedQuantizationSpec:
       (1). find the root edge or node for the edge/node that the qspec points to
       (2). recursively find the root qspec based on the qspec for the root edge/node
    """
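    # For example (hypothetical specs): if the qspec for `a` is SharedQuantizationSpec(b)
    # and the root of `b` maps to a concrete int8 QuantizationSpec, this returns that
    # int8 spec instead of the SharedQuantizationSpec wrapper.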
    if isinstance(qspec, SharedQuantizationSpec):
        sharing_with = qspec.edge_or_node
        root = _find_root_edge_or_node(sharing_with, shared_with_map)
        qspec = edge_or_node_to_qspec[root]
        return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    return qspec


def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "dtype")
        and hasattr(qspec_b, "dtype")
        and qspec_a.dtype == qspec_b.dtype
    )


def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "is_dynamic")
        and hasattr(qspec_b, "is_dynamic")
        and qspec_a.is_dynamic == qspec_b.is_dynamic
    )


def _get_edge_or_node_to_qspec(
    model: torch.fx.GraphModule,
) -> Dict[EdgeOrNode, QuantizationSpecBase]:
    """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes"""
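    # Keys are either a Node (annotated output) or an (arg_node, consumer_node) tuple
    # (annotated input edge), matching the two cases handled below.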
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
    for n in model.graph.nodes:
        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
            qa = n.meta["quantization_annotation"]
            for input_to_n, qspec in qa.input_qspec_map.items():
                input_edge = (input_to_n, n)
                edge_or_node_to_qspec[input_edge] = qspec
            if qa.output_qspec is not None:
                output_node = n
                qspec = qa.output_qspec
                edge_or_node_to_qspec[output_node] = qspec
    return edge_or_node_to_qspec


def _union_input_edge_with(
    input_edge,
    input_edge_root_qspec,
    edge_or_node,
    edge_or_node_to_qspec,
    shared_with_map,
):
    """Union input edge with another edge or node, used in implicit sharing to point the current input
    edge to other user edges of the producer node, or to the output of the producer node, since these
    all refer to the same Tensor
    """
    root_qspec = None
    if edge_or_node in edge_or_node_to_qspec:
        qspec = edge_or_node_to_qspec[edge_or_node]
        root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    # TODO: add assertions for types of root qspecs
    if (
        root_qspec is not None
        and _has_same_dtype(root_qspec, input_edge_root_qspec)
        and _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
    ):
        # the input arg to the node should reuse the existing output observer for arg
        # since dtype is the same (we may want to extend this to be a more strict check
        # in the future)
        # so we point from `input_edge` to `arg` (output of the argument)
        _union(edge_or_node, input_edge, shared_with_map)


def _get_edge_or_node_to_group_id(
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]
) -> Dict[EdgeOrNode, int]:
    """Map from edge/node to the group ID, generated from quantization annotations;
    edges/nodes with the same group ID should use the same observer/fake_quant instance.

    This applies the SharedQuantizationSpec configuration and maps each edge/node to a group.
    There is another implicit sharing built into quantization: when we have the following:
       * op1 -> op2
       * output of op1: int8_qspec
       * (op1 -> op2) input edge: int8_qspec
    we'll assume sharing between the output of op1 and the input of (op1 -> op2) since these are the same Tensor.

    Figuring out the correct group ID for all edges/nodes is a standard union-find problem:
    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/

    Args:
        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
    Returns:
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int); all edges or nodes that
        belong to the same group should have the same id

    Example:
        op2 -> cat1 -> cat2
           op1 /        /
                     op3
        edge_or_node_to_qspec: {
            op1: int8_qspec,
            op2: int8_qspec,
            (op1, cat1): int8_qspec,
            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
            cat1: SharedQuantizationSpec((op1, cat1)),
            (op3, cat2): int8_qspec,
            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
            cat2: SharedQuantizationSpec((op3, cat2)),
        }

        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        edge_or_node_to_group_id: {
            op1: 1,
            op2: 1,
            (op1, cat1): 1,
            (op2, cat1): 1,
            cat1: 1,
            (op3, cat2): 1,
            (cat1, cat2): 1,
            cat2: 1,
        }
        # everything is in the same group because cat1 and (cat1, cat2) are implicitly shared, which
        # connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
    """
    # means the observer of the key should be shared with the observer of the value;
    # by default it will be shared with itself
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {
        k: k for k in edge_or_node_to_qspec.keys()
    }
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        if isinstance(edge_or_node, torch.fx.Node):
            output_node = edge_or_node
            _update_shared_with(output_node, qspec, shared_with_map)
        else:
            input_edge = edge_or_node
            input_edge_root_qspec = _unwrap_shared_qspec(
                qspec, edge_or_node_to_qspec, shared_with_map
            )

            assert isinstance(input_edge, tuple)
            arg, n = input_edge
            if n.meta["quantization_annotation"].allow_implicit_sharing:
                # NOTE: the order is important here, we first share with other users and then share with the previous
                # output, because the reverse order could cause a circular dependency
                # e.g. node1 -> node2
                #           \ -> node3
                # when processing (node1, node2), if we first point (node1, node2) to node1:
                # Step 1. shared_map = {(node1, node2): node1}
                # Step 2. after that, we point (node1, node2) to its other user edge (node1, node3),
                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
                # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
                # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
                # have a circular dependency
                # the following order works around this issue, but it does not allow arbitrary configurations
                # of sharing, so it might break in a different case in the future; when it breaks,
                # quantizer writers can check the notes here to debug the issue

                # sharing with other users of the producer node
                # (arg, user)
                if not isinstance(arg, Node) or not isinstance(n, Node):
                    raise Exception(  # noqa: TRY002
                        f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}"
                    )
                for user in arg.users:
                    if user is n:
                        continue
                    arg_to_user_edge = (arg, user)
                    _union_input_edge_with(
                        input_edge,
                        input_edge_root_qspec,
                        arg_to_user_edge,
                        edge_or_node_to_qspec,
                        shared_with_map,
                    )

                # sharing with output of producer node
                _union_input_edge_with(
                    input_edge,
                    input_edge_root_qspec,
                    arg,
                    edge_or_node_to_qspec,
                    shared_with_map,
                )

            _update_shared_with(input_edge, qspec, shared_with_map)

    # now that we have the sharing relations between all edges and nodes, we can assign group ids
    cur_group_id = 0
    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
    for edge_or_node in shared_with_map.keys():
        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
        if root not in edge_or_node_to_group_id:
            edge_or_node_to_group_id[root] = cur_group_id
            cur_group_id += 1
        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]

    return edge_or_node_to_group_id


def _get_obs_or_fq_map(
    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    is_qat: bool,
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
    """Generates the map from EdgeOrNode to observer/fake_quant instances.
    Makes sure that EdgeOrNodes with the same group_id use the same observer or fake_quant
    instance.
    """
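    # Note: sharing is by object identity; the insertion logic below compares instances
    # with `id(...)` to decide whether an existing observer/fake_quant can be reused.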
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
    group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        group_id = edge_or_node_to_group_id[edge_or_node]
        if group_id not in group_id_to_obs_or_fq:
            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
            # the implementation for _create_obs_or_fq_from_qspec
            group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(
                qspec, obs_or_fq_map, is_qat
            )
        obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
    return obs_or_fq_map


def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node,
                inner_arg,
                qconfig,
                model,
                named_modules,
                obs_or_fq_map,
                is_qat,
            )
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
    original_arg = arg
    while _is_activation_post_process_node(original_arg, named_modules):
        original_arg = original_arg.args[0]  # type: ignore[assignment]
    assert isinstance(
        original_arg, Node
    ), f"expect original argument to be a Node, but got: {type(original_arg)}"

    input_edge = (original_arg, node)
    if input_edge not in obs_or_fq_map:
        return new_arg
    # input_edge needs to be observed
    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
    if input_edge_obs_or_fq is None:
        return new_arg

    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
    # the arg is observed as the output and is using the same instance as the input_edge
    # we'll reuse the inserted observer/fake_quant
    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
        input_edge_obs_or_fq
    ):
        return new_arg

    # otherwise, we'll insert a new observer/fake_quant node

    existing_obs_node = None
    # skip inserting new observers if the same observer instance was already inserted for another user
    # Example:
    # conv1 -> obs1 -> existing_obs -> conv2
    #             \ -> conv3
    #
    # instead of inserting new observers we will have:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                            \ -> conv3
    for maybe_obs_node in arg.users.keys():
        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
            continue
        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
            return maybe_obs_node

    new_arg = _insert_obs_or_fq(
        arg, input_edge_obs_or_fq, model, named_modules, model.graph
    )
    return new_arg


def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    To

      prev_node -> obs -> cur_node

    """
    # Look through every input arg.  If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node,
            arg,
            qconfig,
            model,
            named_modules,
            obs_or_fq_map,
            is_qat,
        )
        new_args.append(new_arg)

    # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
    # gelu has an approximate kwarg; these persist in the exported graph.
    # This is just a workaround for these.
    assert (
        node.target == torch.ops.aten.clone.default
        or node.target == torch.ops.aten.zeros_like.default
        or node.target == torch.ops.aten.gelu.default
        or len(node.kwargs) == 0
    ), " expecting kwargs for aten op IR to be empty"

    # assign the new args to the node, inplace
    node.args = tuple(new_args)


def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
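    # Insert an observer/fake_quant after `node` if its output was annotated (i.e. the
    # node itself is a key in obs_or_fq_map); return the new node, or None otherwise.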
    if node in obs_or_fq_map:
        output_act_obs_or_fq = obs_or_fq_map[node]
        new_output = _insert_obs_or_fq(
            node, output_act_obs_or_fq, model, named_modules, graph
        )
        # propagate numeric debug handle from original node to observer/fake_quant node
        if (
            isinstance(node, Node)
            and isinstance(new_output, Node)
            and CUSTOM_KEY in node.meta
            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
        ):
            if CUSTOM_KEY not in new_output.meta:
                new_output.meta[CUSTOM_KEY] = {}
            new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
                CUSTOM_KEY
            ][NUMERIC_DEBUG_HANDLE_KEY]
        return new_output
    return None


def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
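    # Insert input observers for `node`'s annotated input edges and, if its output is a
    # Tensor, an output observer; then reroute users of `node` to that output observer.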
    this_node_quantization_annotation = (
        node.meta["quantization_annotation"]
        if "quantization_annotation" in node.meta
        else None
    )
    if this_node_quantization_annotation is None:
        return

    named_modules = dict(model.named_modules(remove_duplicate=False))
    _maybe_insert_input_observers_for_node(
        node,
        None,  # qconfig
        model,
        named_modules,
        obs_or_fq_map,
        is_qat,
    )

    output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
    if not output_is_a_tensor:
        return

    # this returns the new observer node if it was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
    )

    if maybe_output_obs_node is None:
        return
    # Update users of original node to use the output observer
    # instead. For example, change
    #
    #           next_node
    #          /
    #   cur_node -> obs
    #
    # to
    #
    #                 next_node
    #                 /
    #   cur_node -> obs
    #
    # We need to save orig users before updating uses because
    # the list of users will change as we update uses
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is maybe_output_obs_node:
            continue
        user_node.replace_input_with(node, maybe_output_obs_node)

def prepare(
    model: GraphModule,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    is_qat: bool,
) -> GraphModule:
    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # At a high level, we construct a map from EdgeOrNode to an observer_or_fake_quant instance;
    # all edges/nodes that belong to the same group will use the same instance,
    # and when we insert observers we just query this map to get the correct
    # observer_or_fake_quant instance
    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
    obs_or_fq_map = _get_obs_or_fq_map(
        edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
    )

    for node in nodes_before_observation:
        # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(
            node, model, obs_or_fq_map, is_qat
        )

    model = GraphModule(model, model.graph)

    _save_state(
        model,
        {},  # node_name_to_qconfig
        node_name_to_scope,
        PrepareCustomConfig(),
        {},  # equalization_node_name_to_qconfig
        QConfigMapping(),
        is_qat,
        set(),  # observed_node_names
    )
    return model

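
# Typical usage (sketch, not part of this module's API): `prepare` is normally reached
# through the higher-level PT2E flow rather than called directly, roughly:
#
#     from torch.ao.quantization.quantize_pt2e import prepare_pt2e
#     # `exported_model` is a GraphModule from torch.export; `quantizer` is a Quantizer
#     # that annotates nodes with `quantization_annotation` metadata
#     prepared_model = prepare_pt2e(exported_model, quantizer)
#
# prepare_pt2e runs the quantizer's annotation pass and then calls `prepare` with the
# resulting annotations to insert observers/fake_quants.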