# mypy: allow-untyped-defs
import torch

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_conv_configs,
    _get_linear_configs,
    _get_share_qparams_op_configs,
    _get_tensor_info_op_configs,
)
from .backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)


__all__ = [
    "get_tensorrt_backend_config",
    "get_tensorrt_backend_config_dict",
]


def get_tensorrt_backend_config() -> BackendConfig:
    """
    Build the `BackendConfig` named "tensorrt" and return it.

    Every registered pattern uses qint8 input and output dtypes. The conv,
    linear, addmm and binary-op patterns get a dtype config that additionally
    declares a qint8 weight and a float bias. The cat, shared-qparams and
    tensor-info patterns get an activation-only qint8 dtype config.

    NOTE: this API is experimental and is expected to change. It exists only to
    unblock experimentation with new backends; please don't depend on it yet.
    TODO: add a README once the API is more stable.
    """
    # qint8 in/out plus a qint8 weight and a float bias.
    qint8_with_weight = DTypeConfig(
        input_dtype=torch.qint8,
        output_dtype=torch.qint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    # qint8 in/out only; no weight or bias dtype is declared.
    qint8_activation_only = DTypeConfig(
        input_dtype=torch.qint8,
        output_dtype=torch.qint8,
    )

    # torch.addmm(input, mat1, mat2) computes input + mat1 @ mat2, so the added
    # matrix (the "bias") is argument 0, the activation is argument 1 and the
    # weight is argument 2.
    addmm_arg_roles = {
        "bias": 0,
        "input": 1,
        "weight": 2,
    }
    addmm_pattern = BackendPatternConfig(torch.addmm)
    addmm_pattern = addmm_pattern.set_observation_type(
        ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    )
    addmm_pattern = addmm_pattern.add_dtype_config(qint8_with_weight)
    addmm_pattern = addmm_pattern._set_input_type_to_index(addmm_arg_roles)

    # torch.cat's output shares its observer with its input
    # (OUTPUT_SHARE_OBSERVER_WITH_INPUT).
    cat_pattern = BackendPatternConfig(torch.cat)
    cat_pattern = cat_pattern.set_observation_type(
        ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
    )
    cat_pattern = cat_pattern.add_dtype_config(qint8_activation_only)

    # Each pattern family gets its own list object. Lists are deliberately not
    # shared between families.
    conv_dtypes = [qint8_with_weight]
    linear_dtypes = [qint8_with_weight]
    binary_op_dtypes = [qint8_with_weight]
    share_qparams_op_dtypes = [qint8_activation_only]
    tensor_info_op_dtypes = [qint8_activation_only]

    # Patterns are not pre-filtered for fx2trt support. Anything fx2trt cannot
    # handle errors out during fx2trt conversion, and support can be added then.
    # NOTE(review): keep this registration order. It may be observable
    # (e.g. in the to_dict() output), so confirm before reordering.
    config = BackendConfig("tensorrt")
    config = config.set_backend_pattern_configs(_get_conv_configs(conv_dtypes))
    config = config.set_backend_pattern_config(addmm_pattern)
    config = config.set_backend_pattern_config(cat_pattern)
    config = config.set_backend_pattern_configs(_get_linear_configs(linear_dtypes))
    config = config.set_backend_pattern_configs(
        _get_binary_op_configs(binary_op_dtypes)
    )
    config = config.set_backend_pattern_configs(
        _get_share_qparams_op_configs(share_qparams_op_dtypes)
    )
    config = config.set_backend_pattern_configs(
        _get_tensor_info_op_configs(tensor_info_op_dtypes)
    )
    return config


def get_tensorrt_backend_config_dict():
    """
    Build the TensorRT `BackendConfig` and return it converted with
    `BackendConfig.to_dict()`.
    """
    tensorrt_config = get_tensorrt_backend_config()
    return tensorrt_config.to_dict()