xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import itertools
from dataclasses import dataclass
from typing import Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn.functional as F
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
    _conv1d_bn_example_inputs,
    _conv2d_bn_example_inputs,
    _get_aten_graph_module_for_pattern,
    _is_conv_node,
    _is_conv_transpose_node,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    QuantizationSpecBase,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


__all__ = [
    "OperatorConfig",
    "OperatorPatternType",
    "QuantizationConfig",
    "get_input_act_qspec",
    "get_output_act_qspec",
    "get_weight_qspec",
    "get_bias_qspec",
    "OP_TO_ANNOTATOR",
    "propagate_annotation",
]


# In the absence of a better name, just winging it with QuantizationConfig
@dataclass(eq=True, frozen=True)
class QuantizationConfig:
    input_activation: Optional[QuantizationSpec]
    output_activation: Optional[QuantizationSpec]
    weight: Optional[QuantizationSpec]
    bias: Optional[QuantizationSpec]
    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
    is_qat: bool = False

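# Illustrative sketch (assumptions, not part of this module's API): a config is
# typically assembled from QuantizationSpec objects, roughly like
#
#   act_qspec = QuantizationSpec(
#       dtype=torch.int8,
#       observer_or_fake_quant_ctr=HistogramObserver,  # hypothetical observer choice
#       quant_min=-128,
#       quant_max=127,
#       qscheme=torch.per_tensor_affine,
#   )
#   weight_qspec = QuantizationSpec(
#       dtype=torch.int8,
#       observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
#       quant_min=-127,
#       quant_max=127,
#       qscheme=torch.per_channel_symmetric,
#       ch_axis=0,
#   )
#   bias_qspec = QuantizationSpec(
#       dtype=torch.float, observer_or_fake_quant_ctr=PlaceholderObserver
#   )
#   quantization_config = QuantizationConfig(act_qspec, act_qspec, weight_qspec, bias_qspec)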

OperatorPatternType = List[Callable]
OperatorPatternType.__module__ = (
    "torch.ao.quantization.quantizer.xnnpack_quantizer_utils"
)

AnnotatorType = Callable[
    [
        torch.fx.GraphModule,
        Optional[QuantizationConfig],
        Optional[Callable[[Node], bool]],
    ],
    Optional[List[List[Node]]],
]
OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}


def register_annotator(op: str):
    def decorator(annotator: AnnotatorType):
        OP_TO_ANNOTATOR[op] = annotator
        # return the annotator unchanged so the decorated function name still
        # refers to the function (instead of None) after decoration
        return annotator

    return decorator

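# Usage sketch (illustrative only; the variable names below are assumptions, not
# part of this module): annotators defined in this file register themselves via
# the decorator above and are then looked up by op name, e.g.
#
#   annotate_linear = OP_TO_ANNOTATOR["linear"]
#   annotated_partitions = annotate_linear(exported_gm, quantization_config, None)
#
# where `exported_gm` is a torch.fx.GraphModule produced by torch.export and
# `quantization_config` is a QuantizationConfig such as the one sketched above.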

class OperatorConfig(NamedTuple):
    # Fix: replace List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
    # Basically we are mapping a quantization config to some list of patterns.
    # A pattern is defined as a list of nn module, function, or builtin function names,
    # e.g. [nn.Conv2d, torch.relu, torch.add].
    # We have not resolved whether fusion can be considered an internal detail of the
    # quantizer, and hence does not need to be communicated to the user.
    # Note that this pattern is not very informative, since it does not tell us
    # the graph structure resulting from the list of ops.
    config: QuantizationConfig
    operators: List[OperatorPatternType]


def _is_annotated(nodes: List[Node]):
    """
    Given a list of nodes (that represents an operator pattern),
    return True if any of the nodes is annotated, otherwise return False.
    """
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated


def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True

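# For reference (a simplified sketch of the state these helpers rely on, not new
# behavior): annotations live in `node.meta["quantization_annotation"]`, so after
# `_mark_nodes_as_annotated([node])` one would observe roughly
#
#   node.meta["quantization_annotation"]._annotated  # -> True
#
# and `_is_annotated([node])` would then return True for that node.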

def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.input_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.input_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec


def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.output_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.output_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec


def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.weight is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.weight
    if quantization_spec.qscheme not in [
        torch.per_tensor_symmetric,
        torch.per_channel_symmetric,
    ]:
        raise ValueError(
            f"Unsupported quantization_spec {quantization_spec} for weight"
        )
    return quantization_spec


def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.bias is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.bias
    assert (
        quantization_spec.dtype == torch.float
    ), "Only float dtype is supported for bias right now"
    return quantization_spec


@register_annotator("linear")
def _annotate_linear(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)
    weight_qspec = get_weight_qspec(quantization_config)
    bias_qspec = get_bias_qspec(quantization_config)
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
            continue
        if filter_fn and not filter_fn(node):
            continue
        act_node = node.args[0]
        weight_node = node.args[1]
        bias_node = None
        if len(node.args) > 2:
            bias_node = node.args[2]

        if not _is_annotated([node]):  # type: ignore[list-item]
            _annotate_input_qspec_map(
                node,
                act_node,
                input_act_qspec,
            )
            _annotate_input_qspec_map(
                node,
                weight_node,
                weight_qspec,
            )
            nodes_to_mark_annotated = [node, weight_node]
            if bias_node:
                _annotate_input_qspec_map(
                    node,
                    bias_node,
                    bias_qspec,
                )
                nodes_to_mark_annotated.append(bias_node)
            _annotate_output_qspec(node, output_act_qspec)
            _mark_nodes_as_annotated(nodes_to_mark_annotated)
            annotated_partitions.append(nodes_to_mark_annotated)

    return annotated_partitions

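# Illustrative sketch of how the "linear" annotator is used (the module `M` and
# the export call below are assumptions for documentation, not defined here):
#
#   exported = torch.export.export(M().eval(), (torch.randn(1, 8),)).module()
#   OP_TO_ANNOTATOR["linear"](exported, quantization_config, None)
#
# After this call, every aten.linear node in `exported` carries a
# QuantizationAnnotation describing qspecs for its activation, weight, and bias
# inputs as well as its output.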

@register_annotator("linear_relu")
def _annotate_linear_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)
    weight_qspec = get_weight_qspec(quantization_config)
    bias_qspec = get_bias_qspec(quantization_config)
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = node
        maybe_linear_node = node.args[0]
        if (
            not isinstance(maybe_linear_node, Node)
            or maybe_linear_node.op != "call_function"
            or maybe_linear_node.target != torch.ops.aten.linear.default
        ):
            continue

        linear_node = maybe_linear_node
        input_qspec_map = {}
        input_act = linear_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = input_act_qspec

        weight = linear_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = weight_qspec

        # adding weight node to the partition as well
        partition = [relu_node, linear_node, weight]
        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = bias_qspec
            partition.append(bias)

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("conv")
def _annotate_conv(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
        ]:
            continue
        conv_node = n

        input_qspec_map = {}
        input_act = conv_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = get_weight_qspec(quantization_config)

        # adding weight node to the partition as well
        partition = [conv_node, conv_node.args[1]]

        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = get_bias_qspec(quantization_config)
            partition.append(bias)

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=get_output_act_qspec(quantization_config),
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


def _do_annotate_conv_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
    is_conv_transpose: bool = False,
):
    annotated_partitions = []
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = n
        maybe_conv_node = n.args[0]

        is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
        if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node):
            continue
        conv_node = maybe_conv_node

        input_qspec_map = {}
        input_act = conv_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = get_weight_qspec(quantization_config)

        # adding weight node to the partition as well
        partition = [relu_node, conv_node, conv_node.args[1]]
        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = get_bias_qspec(quantization_config)
            partition.append(bias)

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map, _annotated=True
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("conv_relu")
def _annotate_conv_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    return _do_annotate_conv_relu(
        gm, quantization_config, filter_fn, is_conv_transpose=False
    )


@register_annotator("conv_transpose_relu")
def _annotate_conv_transpose_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    return _do_annotate_conv_relu(
        gm, quantization_config, filter_fn, is_conv_transpose=True
    )


@register_annotator("conv_bn")
def _annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv + batchnorm partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)


@register_annotator("conv_bn_relu")
def _annotate_conv_bn_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv + batchnorm + relu partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)


@register_annotator("conv_transpose_bn")
def _annotate_conv_transpose_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv_transpose + batchnorm partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(
        gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True
    )


@register_annotator("conv_transpose_bn_relu")
def _annotate_conv_transpose_bn_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv_transpose + batchnorm + relu partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
    """
    return _do_annotate_conv_bn(
        gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True
    )


def _do_annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]],
    has_relu: bool,
    is_conv_transpose: bool = False,
) -> List[List[Node]]:
    """
    Find conv-bn[-relu] patterns in the graph, annotate them, and return a list of
    annotated partitions.

    The output of each match pattern must include a dictionary from string name to node
    for the following names: "input", "conv", "weight", "bias", and "output".
    """

    def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
        def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
            conv = conv_fn(x, conv_weight, conv_bias)
            bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
            if has_relu:
                output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
            else:
                output = bn
            return output, {
                "input": x,
                "conv": conv,
                "weight": conv_weight,
                "bias": conv_bias,
                "output": output,
            }

        return _WrapperModule(_conv_bn)

    # Needed for matching, otherwise the matches get filtered out due to unused
    # nodes returned by batch norm
    gm.graph.eliminate_dead_code()
    gm.recompile()

    matches = []
    if is_conv_transpose:
        combinations = [
            (F.conv_transpose1d, _conv1d_bn_example_inputs),
            (F.conv_transpose2d, _conv2d_bn_example_inputs),
        ]
    else:
        combinations = [
            (F.conv1d, _conv1d_bn_example_inputs),  # type: ignore[list-item]
            (F.conv2d, _conv2d_bn_example_inputs),  # type: ignore[list-item]
        ]

    # Add `is_cuda` and `relu_is_inplace` dimensions
    combinations = itertools.product(  # type: ignore[assignment]
        combinations,
        [True, False] if torch.cuda.is_available() else [False],  # is_cuda
        [True, False] if has_relu else [False],  # relu_is_inplace
    )

    # Match against all conv dimensions and cuda variants
    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:  # type: ignore[misc]
        pattern = get_pattern(conv_fn, relu_is_inplace)  # type: ignore[has-type]
        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
        pattern.graph.eliminate_dead_code()
        pattern.recompile()
        matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
        matches.extend(matcher.match(gm.graph))

    # Annotate nodes returned in the matches
    annotated_partitions = []
    for match in matches:
        name_node_map = match.name_node_map
        input_node = name_node_map["input"]
        conv_node = name_node_map["conv"]
        weight_node = name_node_map["weight"]
        bias_node = name_node_map["bias"]
        output_node = name_node_map["output"]

        # TODO: annotate the uses of input, weight, and bias separately instead
        # of assuming they come from a single conv node. This is not possible today
        # because input may have multiple users, and we can't rely on the conv node
        # always being the first user. This was the case in models with skip
        # connections like resnet18

        # Validate conv args
        if conv_node.args[0] is not input_node:
            raise ValueError(f"Conv arg did not contain input node {input_node}")
        if conv_node.args[1] is not weight_node:
            raise ValueError(f"Conv arg did not contain weight node {weight_node}")
        if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
            raise ValueError(f"Conv arg did not contain bias node {bias_node}")

        # Skip if the partition is already annotated or is filtered out by the user
        partition = [conv_node, weight_node]
        if bias_node is not None:
            partition.append(bias_node)
        if _is_annotated(partition):
            continue
        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        # Annotate conv inputs and pattern output
        input_qspec_map = {}
        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
        if bias_node is not None:
            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        output_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions

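# Sketch of the matching flow above (illustrative; the specific example values
# are assumptions): `get_pattern(F.conv2d, relu_is_inplace=False)` is traced with
# `_conv2d_bn_example_inputs` into an aten-level GraphModule whose output carries
# the name -> node dictionary. SubgraphMatcherWithNameNodeMap then exposes that
# dictionary for every match as `match.name_node_map`, which is why the
# annotation loop can look up "input", "conv", "weight", "bias", and "output"
# directly instead of re-deriving them from the matched subgraph.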

@register_annotator("gru_io_only")
def _annotate_gru_io_only(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
    gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values()))
    annotated_partitions = []
    for gru_partition in gru_partitions:
        annotated_partitions.append(gru_partition.nodes)
        output_nodes = gru_partition.output_nodes
        input_nodes = gru_partition.input_nodes
        # skip annotation if it is already annotated
        if _is_annotated(input_nodes + output_nodes):
            continue
        # inside each GRU partition, we should be able to annotate each linear
        # subgraph
        input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
        input_act = input_nodes[0]
        input_act_user = next(iter(input_act.users.keys()))
        assert isinstance(input_act, Node)
        assert isinstance(input_act_user, Node)
        input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                input_act: get_input_act_qspec(quantization_config),
            },
            _annotated=True,
        )

        hidden_state = input_nodes[1]
        hidden_state_user = next(iter(hidden_state.users.keys()))
        assert isinstance(hidden_state, Node)
        assert isinstance(hidden_state_user, Node)
        hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                hidden_state: get_input_act_qspec(quantization_config),
            },
            _annotated=True,
        )

        assert len(output_nodes) == 2, "expecting GRU to have two outputs"
        for output in output_nodes:
            output.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=get_output_act_qspec(quantization_config),
                _annotated=True,
            )
        nodes_to_mark_annotated = list(gru_partition.nodes)
        _mark_nodes_as_annotated(nodes_to_mark_annotated)
    return annotated_partitions


@register_annotator("adaptive_avg_pool2d")
def _annotate_adaptive_avg_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Always annotate adaptive_avg_pool2d op"""
    module_partitions = get_source_partitions(
        gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
    )
    partitions = list(itertools.chain.from_iterable(module_partitions.values()))
    annotated_partitions = []
    for partition in partitions:
        pool_node = partition.output_nodes[0]
        if (
            pool_node.op != "call_function"
            or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
        ):
            raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")

        if _is_annotated([pool_node]):
            continue

        annotated_partitions.append(partition.nodes)
        input_act = pool_node.args[0]
        assert isinstance(input_act, Node)

        # only annotate an input-output sharing operator
        # when the output of its input node is annotated
        if (
            "quantization_annotation" not in input_act.meta
            or not input_act.meta["quantization_annotation"]._annotated
            or input_act.meta["quantization_annotation"].output_qspec is None
        ):
            input_act_qspec = get_input_act_qspec(quantization_config)
        else:
            input_act_qspec = SharedQuantizationSpec(input_act)

        # output shares the quantization parameters with the input
        output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                input_act: input_act_qspec,
            },
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Check if the input is a large scalar value, so that we can skip quantization
    for the node, since the histc op (in HistogramObserver) only works for values
    up to a certain upper bound
    """
    if node.op == "get_attr":
        qualified_name = str(node.target)
        module_path, _, name = qualified_name.rpartition(".")
        submod = gm.get_submodule(module_path)
        tensor = getattr(submod, name)
        # torch.histc works until this upper bound
        HISTC_UPPER_BOUND = 3.4028235e15
        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
    return False


def _is_input_non_float_tensor(node: Node):
    """Check if the input is not a float tensor, so that we can skip quantization
    for the node, since observers only work with float Tensors
    """
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


@register_annotator("add_relu")
def _annotate_add_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = node
        maybe_add = node.args[0]
        if (
            not isinstance(maybe_add, Node)
            or maybe_add.op != "call_function"
            or maybe_add.target
            not in [
                torch.ops.aten.add.Tensor,
                torch.ops.aten.add_.Tensor,
            ]
        ):
            continue

        add_node = maybe_add
        partition = [relu_node, add_node]

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            partition.append(input_act0)
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            partition.append(input_act1)
            input_qspec_map[input_act1] = input_act_qspec

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("add")
def _annotate_add(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.add.Tensor,
            torch.ops.aten.add_.Tensor,
        ]:
            continue
        add_node = node
        partition = [add_node]

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = input_act_qspec
            partition.append(input_act0)

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec
            partition.append(input_act1)

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("mul_relu")
def _annotate_mul_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = node
        maybe_mul = node.args[0]
        if (
            not isinstance(maybe_mul, Node)
            or maybe_mul.op != "call_function"
            or maybe_mul.target
            not in [
                torch.ops.aten.mul.Tensor,
                torch.ops.aten.mul_.Tensor,
            ]
        ):
            continue

        mul_node = maybe_mul
        partition = [relu_node, mul_node]

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            partition.append(input_act0)
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            partition.append(input_act1)
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("mul")
def _annotate_mul(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.mul_.Tensor,
        ]:
            continue

        mul_node = node
        partition = [mul_node]
        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = input_act_qspec
            partition.append(input_act0)

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec
            # append the second input (not input_act0 again) to the partition
            partition.append(input_act1)

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
        annotated_partitions.append(partition)
    return annotated_partitions


# TODO: remove Optional in return type, fix annotated_partitions logic
@register_annotator("cat")
def _annotate_cat(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
    cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
    annotated_partitions = []
    for cat_partition in cat_partitions:
        cat_node = cat_partition.output_nodes[0]
        if _is_annotated([cat_node]):
            continue

        if cat_node.target != torch.ops.aten.cat.default:
            # TODO: change this to AnnotationException
            raise Exception(  # noqa: TRY002
                f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target},"
                " please check if you are calling the correct capture API"
            )

        annotated_partitions.append(cat_partition.nodes)

        input_act_qspec = get_input_act_qspec(quantization_config)
        inputs = cat_node.args[0]

        input_qspec_map = {}
        input_act0 = inputs[0]  # type: ignore[index]
        if isinstance(input_act0, Node):
            input_qspec_map[input_act0] = input_act_qspec

        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))  # type: ignore[arg-type]
        for input_act in inputs[1:]:  # type: ignore[index]
            input_qspec_map[input_act] = shared_with_input0_qspec  # type: ignore[index]

        output_act_qspec = shared_with_input0_qspec

        cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions

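# Illustrative note (a sketch of the annotation above, not additional behavior):
# for a call like torch.cat([a, b, c]), `a` gets its own observer/quant params,
# while `b`, `c`, and the cat output all reuse them via
# SharedQuantizationSpec((a, cat_node)), so every tensor flowing through the cat
# ends up sharing one set of quantization parameters.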

def _is_share_obs_or_fq_op(op: Callable) -> bool:
    return op in [
        torch.ops.aten.hardtanh.default,
        torch.ops.aten.hardtanh_.default,
        torch.ops.aten.max_pool2d.default,
        torch.ops.aten.mean.default,
        torch.ops.aten.mean.dim,
        torch.ops.aten.permute.default,
        torch.ops.aten.permute_copy.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dim,
        # TODO: remove?
        torch.ops.aten.adaptive_avg_pool2d.default,
        torch.ops.aten.view_copy.default,
        torch.ops.aten.view.default,
        torch.ops.aten.slice_copy.Tensor,
        torch.ops.aten.flatten.using_ints,
    ]


def propagate_annotation(model: torch.fx.GraphModule) -> None:
    for n in model.graph.nodes:
        if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
            continue

        prev_node = n.args[0]
        if not isinstance(prev_node, Node):
            continue

        quantization_annotation = prev_node.meta.get("quantization_annotation", None)
        if not quantization_annotation:
            continue

        output_qspec = quantization_annotation.output_qspec
        if not output_qspec:
            continue

        # make sure current node is not annotated
        if (
            "quantization_annotation" in n.meta
            and n.meta["quantization_annotation"]._annotated
        ):
            continue

        shared_qspec = SharedQuantizationSpec(prev_node)
        # propagate the previous output_qspec to the current node
        n.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                prev_node: shared_qspec,
            },
            output_qspec=shared_qspec,
            _annotated=True,
        )

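# Illustrative example (the graph below is hypothetical): in a chain like
#
#   conv -> view -> permute -> linear
#
# where the conv output is annotated, propagate_annotation visits the view and
# permute nodes (both listed in _is_share_obs_or_fq_op) and gives each one a
# SharedQuantizationSpec pointing at its producer, so the whole chain reuses the
# conv output's observer instead of inserting new ones.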

# TODO: make the list of ops customizable
def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for n in model.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.add.Tensor,
            torch.ops.aten.mul.Tensor,
        ]:
            continue
        args = list(n.args)
        new_args = []
        for i in range(len(args)):
            if isinstance(args[i], torch.fx.Node):
                new_args.append(args[i])
                continue
            prefix = "_tensor_constant_"
            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
            tensor_constant_name = get_new_attr_name(model)
            float_tensor = torch.tensor(float(args[i]))
            model.register_buffer(tensor_constant_name, float_tensor)
            fake_mode = n.meta["val"].fake_mode
            with model.graph.inserting_before(n):
                get_attr_node = model.graph.create_node(
                    "get_attr", tensor_constant_name, (), {}
                )
                get_attr_node.meta["val"] = fake_mode.from_tensor(
                    float_tensor, static_shapes=True
                )
                new_args.append(get_attr_node)
        n.args = tuple(new_args)
    model.recompile()
    return model
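

# Illustrative example (hypothetical graph, not executed here): given a captured
# graph containing `add(x, 2.0)`, _convert_scalars_to_attrs registers a buffer
# such as `_tensor_constant_0 = torch.tensor(2.0)` on the GraphModule, inserts a
# corresponding get_attr node before the add, and rewrites the call to
# `add(x, _tensor_constant_0)`, so the scalar operand can be observed and
# quantized like any other tensor input.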