# mypy: allow-untyped-defs
import itertools
import operator

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized.reference as nnqr
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2
from torch.ao.quantization.utils import MatchAllNode

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_ln_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
)
from .backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)


# ===================
# |  DTYPE CONFIGS  |
# ===================

onednn_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

onednn_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

onednn_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

onednn_weight_only_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
)

onednn_input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)

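# As a rough cross-check (a sketch only; exact observer choices can differ
# across PyTorch versions), these dtype configs line up with the activation
# and weight dtypes of the default oneDNN qconfig:
#
#   from torch.ao.quantization import get_default_qconfig
#   qconfig = get_default_qconfig("onednn")
#   assert qconfig.activation().dtype == torch.quint8
#   assert qconfig.weight().dtype == torch.qint8
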
# ===================
# |  FUSER METHODS  |
# ===================


def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
    r"""Given the linear, bn and leaky_relu modules, fuse them and return the fused module.

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer
        leaky_relu: LeakyReLU instance that needs to be fused with the linear layer

    Examples::
        >>> # xdoctest: +SKIP
        >>> m1 = nn.Linear(20, 10).eval()
        >>> b1 = nn.BatchNorm1d(10).eval()
        >>> lr = nn.LeakyReLU(0.01).eval()
        >>> m2 = _fuse_linear_bn_leaky_relu(False, m1, b1, lr)
    """
    assert (
        linear.training == bn.training and bn.training == leaky_relu.training
    ), "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."

    if is_qat:
        raise NotImplementedError(
            f"Cannot fuse train modules: {(linear, bn, leaky_relu)}"
        )
    else:
        map_to_fused_module_eval = {
            nn.Linear: nni.LinearLeakyReLU,
        }
        fused_module = map_to_fused_module_eval.get(type(linear), None)
        if fused_module is not None:
            fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
            fm = fused_module(fused_linear, leaky_relu)
            return fm
        else:
            raise NotImplementedError(
                f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}"
            )


# ======================
# |  CONFIGS FOR CONV  |
# ======================
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
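# i.e. the output of each fused pattern configured below gets its own observer
# rather than sharing the observer of its input.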

conv_dtype_configs = [onednn_weighted_op_int8_dtype_config]
conv_configs = _get_conv_configs(conv_dtype_configs)

# (1) Conv2d + Add

# conv2d   Y
#   \   /
#    add

# include:
# conv2d conv2d
#   \   /
#    add


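# For illustration only (a hypothetical module, not part of this config), the
# pattern above is meant to match graphs such as:
#
#   class M(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = torch.nn.Conv2d(3, 3, 3)
#
#       def forward(self, x, y):
#           return self.conv(x) + y  # conv2d output added to another tensor Y
#
# The matched pattern is passed to the helpers below in the same order as the
# pattern tuple, i.e. (add, conv, extra_input): the conv node is the root that
# the fused module replaces, and the other input of the add is forwarded to the
# fused module as an extra input.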
def _fuse_conv_add_left(is_qat, add, conv, _):
    return nni.ConvAdd2d(conv, add)


def _conv_add_root_node_getter_left(pattern):
    _, conv, _ = pattern
    return conv


def _conv_add_extra_inputs_getter_left(pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, conv, extra_input = pattern
    return [extra_input]


# conv2d
#  \
#  bn   Y
#   \   /
#    add


def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)


def _conv_bn_add_root_node_getter_left(add_pattern):
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_extra_inputs_getter_left(add_pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_left_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_left_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_left)
            ._set_root_node_getter(_conv_bn_add_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAdd2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, nn.Conv2d, MatchAllNode)
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_left)
            ._set_root_node_getter(_conv_add_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAdd2d)
        )

#  Y   conv2d
#   \   /
#    add


def _fuse_conv_add_right(is_qat, add, _, conv):
    return nni.ConvAdd2d(conv, add)


def _conv_add_root_node_getter_right(pattern):
    add, _, conv = pattern
    return conv


def _conv_add_extra_inputs_getter_right(pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, extra_input, conv = pattern
    return [extra_input]


#      conv2d
#        /
#  Y    bn
#   \   /
#    add


def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)


def _conv_bn_add_root_node_getter_right(pattern):
    add, _, bn_conv = pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_extra_inputs_getter_right(pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, extra_input, bn_conv = pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_right)
            ._set_root_node_getter(_conv_bn_add_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAdd2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, MatchAllNode, nn.Conv2d)
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_right)
            ._set_root_node_getter(_conv_add_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAdd2d)
        )

conv_configs.append(
    BackendPatternConfig(nni.ConvAdd2d)
    .set_observation_type(observation_type)  # noqa: E131
    .set_dtype_configs(conv_dtype_configs)
    .set_root_module(nn.Conv2d)
    .set_reference_quantized_module(nnqr.Conv2d)
)

# (2) Conv2d + Add + Relu

# conv2d Y
#   \   /
#    add
#     \
#     relu

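# For the relu patterns below, the complex pattern format is nested, e.g.
# (nn.ReLU, (add_op, nn.Conv2d, MatchAllNode)), so the helpers receive
# (relu, add_pattern) and unpack add_pattern as (add, conv, extra_input).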

def _fuse_conv_add_relu_left(is_qat, relu, add_pattern):
    add, conv, _ = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)


def _conv_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, conv, _ = add_pattern
    return conv


def _conv_add_relu_extra_inputs_getter_left(pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, conv, extra_input = add_pattern
    return [extra_input]


# conv2d
#  \
#  bn   Y
#   \   /
#    add
#     \
#     relu


def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern):
    add, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)


def _conv_bn_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_relu_extra_inputs_getter_left(pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_relu_left_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_relu_left_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_relu_left)
            ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAddReLU2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_relu_left)
            ._set_root_node_getter(_conv_add_relu_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAddReLU2d)
        )

#  Y   conv2d
#   \   /
#    add
#     \
#     relu


def _fuse_conv_add_relu_right(is_qat, relu, add_pattern):
    add, _, conv = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)


def _conv_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, conv = add_pattern
    return conv


def _conv_add_relu_extra_inputs_getter_right(pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, extra_input, conv = add_pattern
    return [extra_input]


#      conv2d
#        /
#  Y    bn
#   \   /
#    add
#     \
#     relu


def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern):
    add, _, bn_conv = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)


def _conv_bn_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, bn_conv = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_relu_extra_inputs_getter_right(pattern):
    """Get the extra input nodes of the pattern; inputs of the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, extra_input, bn_conv = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_relu_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_relu_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_relu_right)
            ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAddReLU2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_relu_right)
            ._set_root_node_getter(_conv_add_relu_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAddReLU2d)
        )

conv_configs.append(
    BackendPatternConfig(nni.ConvAddReLU2d)
    .set_observation_type(observation_type)  # noqa: E131
    .set_dtype_configs(conv_dtype_configs)
    .set_root_module(nn.Conv2d)
    .set_reference_quantized_module(nnqr.Conv2d)
)

# ========================
# |  CONFIGS FOR LINEAR  |
# ========================

linear_dtype_configs = [
    onednn_weighted_op_int8_dtype_config,
    onednn_dynamic_int8_dtype_config,
]
linear_configs = _get_linear_configs(linear_dtype_configs)


def _add_eltwise_fusion_configs(
    configs,
    root_module,
    root_op,
    post_module,
    post_op,
    dtype_configs,
    fuser_method,
    fused_module,
    observation_type,
    ref_quant_module,
):
    # 1 base module + op module fusion config
    configs.append(
        BackendPatternConfig((root_module, post_module))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuser_method)
        .set_fused_module(fused_module)
    )
    # base module + functional post op
    configs.append(
        BackendPatternConfig((root_module, post_op))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuser_method)
        .set_fused_module(fused_module)
    )

    # 2 fused module configs
    configs.append(
        BackendPatternConfig(fused_module)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(root_module)
        .set_reference_quantized_module(ref_quant_module)
    )

    # 3 functional base op + post op configs
    configs.append(
        BackendPatternConfig((root_op, post_module))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    configs.append(
        BackendPatternConfig((root_op, post_op))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )


# Configs for linear + leaky_relu fusion
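# Note: _sequential_wrapper2(nni.LinearLeakyReLU) builds the fuser method used
# below; roughly, (is_qat, linear, leaky_relu) -> nni.LinearLeakyReLU(linear, leaky_relu).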
_add_eltwise_fusion_configs(
    linear_configs,
    nn.Linear,
    F.linear,
    nn.LeakyReLU,
    F.leaky_relu,
    linear_dtype_configs,
    _sequential_wrapper2(nni.LinearLeakyReLU),
    nni.LinearLeakyReLU,
    observation_type,
    nnqr.Linear,
)

# Configs for linear module + batchnorm + leaky_relu
linear_configs.append(
    BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU))
    .set_dtype_configs(linear_dtype_configs)  # noqa: E131
    .set_fuser_method(_fuse_linear_bn_leaky_relu)
    .set_fused_module(nni.LinearLeakyReLU)
)

# Configs for linear + tanh fusion
_add_eltwise_fusion_configs(
    linear_configs,
    nn.Linear,
    F.linear,
    nn.Tanh,
    torch.tanh,
    linear_dtype_configs,
    _sequential_wrapper2(nni.LinearTanh),
    nni.LinearTanh,
    observation_type,
    nnqr.Linear,
)

# ===========================
# |  CONFIGS FOR OTHER OPS  |
# ===========================

binary_op_dtype_configs = [onednn_op_quint8_dtype_config]
default_op_dtype_configs = [onednn_op_quint8_dtype_config]
fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config]
embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config]
layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config]

# =====================
# |  BACKEND CONFIGS  |
# =====================


def get_onednn_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native ONEDNN backend.
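
    Example (a minimal sketch; assumes the FX graph mode quantization APIs
    ``prepare_fx`` / ``convert_fx``)::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.ao.quantization import get_default_qconfig_mapping
        >>> from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        >>> model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LeakyReLU()).eval()
        >>> qconfig_mapping = get_default_qconfig_mapping("onednn")
        >>> backend_config = get_onednn_backend_config()
        >>> example_inputs = (torch.randn(1, 4),)
        >>> prepared = prepare_fx(
        ...     model, qconfig_mapping, example_inputs, backend_config=backend_config
        ... )
        >>> quantized = convert_fx(prepared, backend_config=backend_config)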
627    """
628    return (
629        BackendConfig("onednn")
630        .set_backend_pattern_configs(conv_configs)
631        .set_backend_pattern_configs(linear_configs)
632        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
633        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
634        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
635        .set_backend_pattern_configs(
636            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
637        )
638        .set_backend_pattern_configs(
639            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
640        )
641        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
642        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
643        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
644        .set_backend_pattern_configs(
645            _get_embedding_op_configs(embedding_op_dtype_configs)
646        )
647    )
648
649
650__all__ = [
651    "get_onednn_backend_config",
652]
653