# /aosp_15_r20/external/pytorch/torch/ao/ns/fx/mappings.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
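"""
Op mappings for the FX graph mode Numeric Suite (torch.ao.ns.fx): sets of
related ops (a floating point op and its fused / QAT / reference / quantized
counterparts), node type to input/output dtype classifications, and node
types which are not matchable.
"""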
import operator
from typing import Callable, Dict, List, Optional, Set, Tuple

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.quantization.fx._lower_to_native_backend as _lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.backend_config import get_native_backend_config

from .ns_types import NSNodeTargetType

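# shorthand alias for the quantized ops namespace (used below, e.g. toq.linear_relu, toq.prelu)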
toq = torch.ops.quantized


def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
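    """
    Returns a mapping from an arbitrary base name (a stringified counter) to a
    set of related ops, where each set groups a floating point op with its
    fused, QAT, reference and quantized counterparts.
    """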
    # note: this set is modified below by items from backend_config
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {
            nn.Conv1d,
        },
        {
            nn.Conv2d,
        },
        {
            nn.Conv3d,
        },
        # conv functionals
        {
            F.conv1d,
        },
        {
            F.conv2d,
        },
        {
            F.conv3d,
        },
        # linear modules
        {
            nn.Linear,
        },
        # linear functionals
        {
            F.linear,
        },
        # average pool
        {
            nn.AvgPool1d,
            torch.avg_pool1d,
        },
        {
            nn.AvgPool2d,
            torch._C._nn.avg_pool2d,
        },
        {
            nn.AvgPool3d,
            torch._C._nn.avg_pool3d,
        },
        # adaptive average pool
        {
            nn.AdaptiveAvgPool1d,
            F.adaptive_avg_pool1d,
        },
        {
            nn.AdaptiveAvgPool2d,
            F.adaptive_avg_pool2d,
        },
        {
            nn.AdaptiveAvgPool3d,
            F.adaptive_avg_pool3d,
        },
        # LSTM
        {
            nn.LSTM,
        },
        # add
        {
            torch.add,
            operator.add,  # x + y
        },
        # cat
        {
            torch.cat,
        },
        # mul
        {
            torch.mul,
            operator.mul,
        },
        # relu
        {
            F.relu,
            nn.ReLU,
            "relu",
            "relu_",
            torch.relu,
        },
        # maxpool
        {
            nn.MaxPool1d,
            F.max_pool1d,
        },
        {
            nn.MaxPool2d,
            F.max_pool2d,
        },
        {
            nn.MaxPool3d,
            F.max_pool3d,
        },
        # sigmoid
        {
            torch.sigmoid,
            "sigmoid",
            "sigmoid_",
            nn.Sigmoid,
            F.sigmoid,
        },
        # BatchNorm
        {
            nn.BatchNorm2d,
        },
        {
            nn.BatchNorm3d,
        },
        # ConvTranspose
        {
            nn.ConvTranspose1d,
        },
        {
            nn.ConvTranspose2d,
        },
        {
            nn.ConvTranspose3d,
        },
        # functional transposed conv
        {
            F.conv_transpose1d,
        },
        {
            F.conv_transpose2d,
        },
        {
            F.conv_transpose3d,
        },
        # ELU
        {
            nn.ELU,
        },
        # Embedding
        {
            nn.Embedding,
        },
        # EmbeddingBag
        {
            nn.EmbeddingBag,
        },
        # GroupNorm
        {
            nn.GroupNorm,
        },
        # Hardswish
        {
            nn.Hardswish,
        },
        # InstanceNorm
        {
            nn.InstanceNorm1d,
        },
        {
            nn.InstanceNorm2d,
        },
        {
            nn.InstanceNorm3d,
        },
        # LayerNorm
        {
            nn.LayerNorm,
        },
        # LeakyReLU
        {
            nn.LeakyReLU,
        },
        # ReLU6
        {
            nn.ReLU6,
            F.relu6,
        },
        # F.elu
        {
            F.elu,
        },
        # F.hardswish
        {
            F.hardswish,
        },
        # F.group_norm
        {
            F.group_norm,
        },
        # F.instance_norm
        {
            F.instance_norm,
        },
        # F.layer_norm
        {
            F.layer_norm,
        },
        # F.leaky_relu
        {
            F.leaky_relu,
        },
        # F.silu
        {
            nn.SiLU,
            F.silu,
        },
        # F.mish
        {
            nn.Mish,
            F.mish,
        },
        # F.tanh
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            "tanh_",
            "tanh",
        },
        # F.hardsigmoid
        {
            "hardsigmoid_",
            "hardsigmoid",
            F.hardsigmoid,
            nn.Hardsigmoid,
        },
        # F.hardtanh
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        },
        # floordiv
        {
            operator.floordiv,
        },
        # unsqueeze
        {
            torch.unsqueeze,
        },
        # stack
        {
            torch.stack,
        },
        # squeeze
        {
            torch.squeeze,
        },
        # sort
        {
            torch.sort,
        },
        # repeat_interleave
        {
            torch.repeat_interleave,
        },
        # min
        {
            torch.min,
        },
        # mean
        {
            torch.mean,
        },
        # max
        {
            torch.max,
        },
        # transpose
        {
            torch.transpose,
        },
        # flatten
        {
            torch.flatten,
        },
        # clamp
        {
            torch.clamp,
        },
        # chunk
        {
            torch.chunk,
        },
        # interpolate
        {
            torch.nn.functional.interpolate,
        },
        # dropout
        {
            nn.Dropout,
        },
        # F.dropout
        {
            F.dropout,
        },
        # matmul
        {
            torch.matmul,
        },
        # Softmax
        {
            nn.Softmax,
        },
        # PReLU
        {
            nn.PReLU,
            nnq.PReLU,
        },
        # F.prelu
        {
            F.prelu,
            toq.prelu,
        },
        # pixel shuffle
        {
            nn.PixelShuffle,
        },
        {
            F.pixel_shuffle,
        },
        # pixel unshuffle
        {
            nn.PixelUnshuffle,
        },
        {
            F.pixel_unshuffle,
        },
        # narrow
        {
            torch.narrow,
        },
    ]

    # for each floating point op, add versions of the op added by
    # backend_config
    backend_config = get_native_backend_config()

    new_connections: List[Tuple[Callable, Callable]] = [
        # technical debt edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        # pattern format: (c, (b, a))
        first_element = pattern
        # look from the end, because pattern is in reverse order
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]

        if config.fused_module is not None:
            # case 1: pattern fuses a pattern of ops into an op
            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
            new_connections.append((first_element, config.fused_module))

        if config.qat_module is not None:
            # case 2: pattern swaps a module into a QAT module
            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
            new_connections.append((first_element, config.qat_module))

        if config.reference_quantized_module is not None:
            # case 3: reference version of floating point module, such as
            # nn.Conv2d and nnqr.Conv2d
            new_connections.append((first_element, config.reference_quantized_module))

    #
    # Add reference module swaps from default lowering path
    #

    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        for source, target in source_to_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target))

    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target1))
            new_connections.append((source, target2))

    #
    # Add function swaps from default lowering path
    #

    for source, (
        target1,
        target2,
    ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
        new_connections.append((source, target1))
        new_connections.append((source, target2))

    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    #
    # Add other swaps, ideally in the future this could be removed
    # after the lowering code stops using these.
    #
    for source_to_target in (
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    # add the new connections from backend_config
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.add(item1)
                set_of_related_ops.add(item2)
                break

    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {}

    counter = 0
    for set_of_related_ops in sets_of_related_ops:
        base_name = str(counter)
        counter += 1
        base_name_to_sets_of_related_ops[base_name] = set_of_related_ops

    return base_name_to_sets_of_related_ops


def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
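    """
    Returns the base name whose set of related ops contains `op`, or None if
    `op` is not present in any set.
    """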
    for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
        if op in set_of_related_ops:
            return base_name
    return None


def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
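    """
    Adds `op` to the set of related ops containing `related_op`. If
    `related_op` is None, creates a new set containing only `op` under a
    fresh base name. Raises AssertionError if `related_op` is given but not
    found in any existing set.
    """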
    if related_op is not None:
        for set_of_related_ops in base_name_to_sets_of_related_ops.values():
            if related_op in set_of_related_ops:
                set_of_related_ops.add(op)
                return
        # if we got here, related_op was not found
        raise AssertionError(f"{related_op} was not found")
    else:
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}


# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
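    """
    Returns sets of functions, modules and methods grouped by their expected
    input/output dtype (fp32, fp16, int8, or either fp32 or int8), keyed by
    category name.
    """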
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
        # TODO(future PR): implement shadowing for binary ops and
        # uncomment below
        # toq.add,
        # toq.mul,
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        F.pixel_shuffle,
        F.pixel_unshuffle,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.narrow,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        # note: nnqd.Linear is an instance of nnq.Linear, so this
        # check has to happen before the int8 module check
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
        nni.LinearLeakyReLU,
        nni.LinearTanh,
        nni.ConvAdd2d,
        nni.ConvAddReLU2d,
    }

    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
        nniq.LinearLeakyReLU,
        nniq.LinearTanh,
        nniq.ConvAdd2d,
        nniq.ConvAddReLU2d,
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.PixelShuffle,
        nn.PixelUnshuffle,
        nn.ReLU6,
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        "sigmoid_",
        "sigmoid",
        "tanh_",
        "tanh",
        "hardsigmoid_",
        "hardsigmoid",
        "relu_",
        "relu",
    }

    return {
        "funs_io_type_fp32": FUNS_IO_TYPE_FP32,
        "funs_io_type_fp16": FUNS_IO_TYPE_FP16,
        "funs_io_type_int8": FUNS_IO_TYPE_INT8,
        "funs_io_type_fp32_or_int8": FUNS_IO_TYPE_FP32_OR_INT8,
        "mods_io_type_fp32": MODS_IO_TYPE_FP32,
        "mods_io_type_int8": MODS_IO_TYPE_INT8,
        "mods_io_type_fp32_or_int8": MODS_IO_TYPE_FP32_OR_INT8,
        "meths_io_type_fp32_or_int8": METHS_IO_TYPE_FP32_OR_INT8,
    }


def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
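    """
    Returns sets of functions, modules and methods which should be treated as
    unmatchable when relating nodes between graphs.
    """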
    FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
        torch.quantize_per_tensor,
        operator.getitem,
    }

    MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
        nn.Identity,
    }

    METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
        "to",
        "dequantize",
        "reshape",
        "view",
        "unsqueeze_",
        "unsqueeze",
        "transpose",
        "squeeze_",
        "squeeze",
        "size",
        "shape",
        "resize_",
        "repeat_interleave",
        "repeat",
        "permute",
        "numel",
        "mean",
        "detach_",
        "detach",
        "contiguous",
        "clamp",
        "chunk",
    }

    return {
        "funs_unmatchable": FUNS_UNMATCHABLE,
        "mods_unmatchable": MODS_UNMATCHABLE,
        "meths_unmatchable": METHS_UNMATCHABLE,
    }