xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/backend_config/tensorrt.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from ._common_operator_config_utils import (
5    _get_binary_op_configs,
6    _get_conv_configs,
7    _get_linear_configs,
8    _get_share_qparams_op_configs,
9    _get_tensor_info_op_configs,
10)
11from .backend_config import (
12    BackendConfig,
13    BackendPatternConfig,
14    DTypeConfig,
15    ObservationType,
16)
17
18
# Public API of this module: the *_dict variant returns the same config
# serialized via BackendConfig.to_dict().
__all__ = [
    "get_tensorrt_backend_config",
    "get_tensorrt_backend_config_dict",
]
23
24
25def get_tensorrt_backend_config() -> BackendConfig:
26    """
27    Return the `BackendConfig` for the TensorRT backend.
28    NOTE: Current api will change in the future, it's just to unblock experimentation for
29    new backends, please don't use it right now.
30    TODO: add a README when it's more stable
31    """
32    # dtype configs
33    weighted_op_qint8_dtype_config = DTypeConfig(
34        input_dtype=torch.qint8,
35        output_dtype=torch.qint8,
36        weight_dtype=torch.qint8,
37        bias_dtype=torch.float,
38    )
39    non_weighted_op_qint8_dtype_config = DTypeConfig(
40        input_dtype=torch.qint8,
41        output_dtype=torch.qint8,
42    )
43
44    addmm_config = (
45        BackendPatternConfig(torch.addmm)
46        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
47        .add_dtype_config(weighted_op_qint8_dtype_config)
48        ._set_input_type_to_index(
49            {
50                "bias": 0,
51                "input": 1,
52                "weight": 2,
53            }
54        )
55    )
56    cat_config = (
57        BackendPatternConfig(torch.cat)
58        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
59        .add_dtype_config(non_weighted_op_qint8_dtype_config)
60    )
61    conv_dtype_configs = [
62        weighted_op_qint8_dtype_config,
63    ]
64    linear_dtype_configs = [
65        weighted_op_qint8_dtype_config,
66    ]
67    binary_op_dtype_configs = [
68        weighted_op_qint8_dtype_config,
69    ]
70    share_qparams_op_dtype_configs = [
71        non_weighted_op_qint8_dtype_config,
72    ]
73    tensor_info_op_dtype_configs = [
74        non_weighted_op_qint8_dtype_config,
75    ]
76    # there might be things not supported in fx2trt, but it will error out
77    # during fx2trt conversion and can support them after that
78    return (
79        BackendConfig("tensorrt")
80        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
81        .set_backend_pattern_config(addmm_config)
82        .set_backend_pattern_config(cat_config)
83        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
84        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
85        .set_backend_pattern_configs(
86            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
87        )
88        .set_backend_pattern_configs(
89            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
90        )
91    )
92
93
94def get_tensorrt_backend_config_dict():
95    """
96    Return the `BackendConfig` for the TensorRT backend in dictionary form.
97    """
98    return get_tensorrt_backend_config().to_dict()
99