# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np

import torch

from executorch.exir import EdgeProgramManager
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass
from executorch.exir.tensor import scalar_type_enum
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)


def quantize_input(
    exported_program, input_index, qparams: Optional[Dict[str, Any]] = None
):
    """
    Modify the program to expect quantized input at the given index.

    The model is expected to be quantizing this input as its first step: the
    input placeholder must have exactly one user, and that user must be a
    ``quantize_per_tensor`` op. Must be called before permute_input_layout.

    If ``qparams`` is given and its dtype, scale or zero point differ from the
    existing quantize op's arguments, the quantize op is kept and a
    ``dequantize_per_tensor`` op using ``qparams`` is inserted in front of it
    (requantization). Otherwise the quantize op is removed and the placeholder
    itself becomes the quantized input.

    Args:
        exported_program: Program whose graph is modified in place.
        input_index: Index into ``graph_signature.user_inputs``.
        qparams: Optional dict with keys "scale", "zp" and "dtype" describing
            the quantization the caller will supply the input in.

    Returns:
        The (scale, zero point, qmin, qmax, dtype) of the expected quantization.

    Raises:
        ValueError: If the input does not have exactly one user, that user is
            not a quantize op, or the requested dtype change is unsupported.
    """
    graph = exported_program.graph_module.graph
    name = exported_program.graph_signature.user_inputs[input_index]
    placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name]
    assert placeholders, f"No placeholder found for input {input_index}"
    target_placeholder = placeholders[0]

    num_users = len(target_placeholder.users)
    if num_users != 1:
        # The message must also be correct for the zero-user case.
        raise ValueError(
            f"Input {input_index} has {num_users} users, expected exactly one"
        )
    quantize = next(iter(target_placeholder.users))
    if (
        quantize.target
        != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    ):
        raise ValueError(f"Input {input_index} is not used by a quantize op")

    # If user specified qparams are different from args of quantize op, we do
    # requantization instead of eliminating quantize op.
    # quantize_per_tensor args: (input, scale, zero_point, qmin, qmax, dtype).
    need_requant = False
    if qparams is not None:
        assert all(
            qparam in qparams for qparam in ["scale", "zp", "dtype"]
        ), "dtype/scale/zp must be specified in qparam for input requantization"
        if qparams["dtype"] != quantize.args[5]:
            if any(
                dtype
                not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16]
                for dtype in [qparams["dtype"], quantize.args[5]]
            ):
                raise ValueError(
                    f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}"
                )

            need_requant = True
        elif (
            not np.isclose(qparams["scale"], quantize.args[1])
            or qparams["zp"] != quantize.args[2]
        ):
            need_requant = True

    if need_requant:
        assert qparams is not None
        dtype = qparams["dtype"]
        # NOTE(review): torch.iinfo does not accept torch.bool even though it is
        # in the allowed list above — confirm whether bool requant is reachable.
        qmin = torch.iinfo(dtype).min
        qmax = torch.iinfo(dtype).max
        scale = qparams["scale"]
        zero_point = qparams["zp"]
        quant_args = (scale, zero_point, qmin, qmax, dtype)
        logger.info(
            f"Modifying program to requantize quantized input at index {input_index}"
        )
        logger.info(f"Quantization parameters: {quant_args}")

        with graph.inserting_before(quantize):
            input_dequant = graph.call_function(
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                args=(
                    target_placeholder,
                    *quant_args,
                ),
            )
        input_dequant.meta["input_qparams"] = [
            {
                "scale": scale,
                "zero_point": zero_point,
                "qmin": qmin,
                "qmax": qmax,
                "dtype": dtype,
            }
        ]
        input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32)
        target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype)
        # Route the placeholder through the new dequant op into the existing
        # quantize op, which re-quantizes with the program's own qparams.
        quantize.replace_input_with(target_placeholder, input_dequant)
    else:
        quant_args = quantize.args[1:]
        logger.info(f"Modifying program to take quantized input at index {input_index}")
        logger.info(f"Quantization parameters: {quant_args}")

        # The placeholder now carries already-quantized data, so its example
        # value is replaced by the quantized version.
        target_placeholder.meta["val"] = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
                target_placeholder.meta["val"], *quant_args
            )
        )
        # Bypass the quantize op; it becomes dead and is removed below.
        quantize.replace_all_uses_with(quantize.args[0])

    graph.eliminate_dead_code()
    return quant_args


def quantize_output(exported_program, output_index):
    """
    Modify the program to produce quantized output at the given index.

    The model is expected to be dequantizing this output as its last step: the
    output value at ``output_index`` must be a ``dequantize_per_tensor`` op,
    which is removed so the program returns its quantized input directly.
    Must be called before permute_output_layout.

    Args:
        exported_program: Program whose graph is modified in place.
        output_index: Position in the graph's single output node's output list.

    Returns:
        The scale, zero point, qmin, qmax, and dtype of the output quantization
        (the removed dequantize op's arguments after its input).

    Raises:
        NotImplementedError: If the graph does not have exactly one output node.
        ValueError: If ``output_index`` is out of range or the output is not a
            dequantize op.
    """
    graph = exported_program.graph_module.graph
    outputs = [n for n in graph.nodes if n.op == "output"]
    if len(outputs) != 1:
        raise NotImplementedError("Only 1 output node is supported")

    output_node = outputs[0]
    output_list = list(output_node.args[0])
    if output_index >= len(output_list):
        raise ValueError(
            f"{len(output_list)} outputs available, "
            + f"output index out of bounds: {output_index}"
        )

    target_output = output_list[output_index]
    # Outputs need not be fx Nodes (e.g. None or constants); treat anything
    # without a ``target`` as "not a dequantize op" instead of crashing.
    if (
        getattr(target_output, "target", None)
        != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    ):
        raise ValueError(f"Output {output_index} is not a dequantize op")

    dequant = target_output
    output_list[output_index] = dequant.args[0]
    output_node.args = (output_list,)
    dequant_args = dequant.args[1:]
    graph.eliminate_dead_code()

    logger.info(
        f"Modifying program to produce quantized output at index {output_index}"
    )
    logger.info(f"Dequantization parameters: {dequant_args}")
    return dequant_args


def get_config_method_name(
    prefix: Optional[str] = "forward",
    arg_type: str = "input",
    index: int = 0,
    key: str = "scale",
):
    """
    Build the config-method name under which a quantization parameter is stored.

    The result has the form ``"{prefix}_{arg_type}{index}_{key}"``, or
    ``"{arg_type}{index}_{key}"`` when ``prefix`` is None, e.g.
    ``"forward_input0_scale"``.
    """
    if prefix is None:
        prefix = ""
    else:
        prefix = prefix + "_"
    assert arg_type in ["input", "output"], "arg_type must be either input or output"
    assert index >= 0, "index must be non-negative"
    assert key in [
        "scale",
        "zp",
        "quant_min",
        "quant_max",
        "dtype",
    ], "key must be one of scale, zp, quant_min, quant_max, dtype"
    return f"{prefix}{arg_type}{index}_{key}"


def _record_quant_params(
    edge_program_manager: EdgeProgramManager,
    prefix: Optional[str],
    arg_type: str,
    index: int,
    quant_args,
) -> None:
    """Store (scale, zp, quant_min, quant_max, dtype) as config methods."""
    if not edge_program_manager._config_methods:
        edge_program_manager._config_methods = {}

    config_methods = edge_program_manager._config_methods
    for key, value in (
        ("scale", quant_args[0]),
        ("zp", quant_args[1]),
        ("quant_min", quant_args[2]),
        ("quant_max", quant_args[3]),
    ):
        config_methods[  # pyre-ignore
            get_config_method_name(prefix, arg_type, index, key)
        ] = value
    # dtype is stored as its ExecuTorch scalar-type enum rather than a torch.dtype.
    config_methods[  # pyre-ignore
        get_config_method_name(prefix, arg_type, index, "dtype")
    ] = scalar_type_enum(quant_args[4])


class QuantizeInputs(ExportPass):
    """
    Pass that makes the selected program inputs take quantized data.

    For each selected input index it runs ``quantize_input`` on
    ``edge_program_manager.exported_program()``. It then records the resulting
    scale, zp, quant_min, quant_max and dtype in the manager's
    ``_config_methods``, under names from ``get_config_method_name``.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose exported program is modified.
            quantized_inputs_idx: Either a list of input indices, or a dict
                mapping input index to requantization qparams (see
                ``quantize_input``).
            method_name: Prefix for the recorded config-method names.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager

        self.quantized_inputs_idx_dict: Dict[int, Optional[Dict[str, Any]]]
        if isinstance(quantized_inputs_idx, dict):
            self.quantized_inputs_idx_dict = quantized_inputs_idx
        else:
            # A plain list means "no requantization" for every listed input.
            self.quantized_inputs_idx_dict = {
                idx: None for idx in quantized_inputs_idx
            }
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule):
        # NOTE(review): edits are applied to edge_program_manager.exported_program(),
        # not to graph_module directly — presumably the same graph; confirm.
        for i, qparams in self.quantized_inputs_idx_dict.items():
            quant_args = quantize_input(
                self.edge_program_manager.exported_program(), i, qparams
            )
            _record_quant_params(
                self.edge_program_manager,
                self.param_prefix_name,
                "input",
                i,
                quant_args,
            )
        return PassResult(graph_module, True)


class QuantizeOutputs(ExportPass):
    """
    Pass that makes the selected program outputs produce quantized data.

    For each selected output index it runs ``quantize_output`` on
    ``edge_program_manager.exported_program()``. It then records the removed
    dequantize op's scale, zp, quant_min, quant_max and dtype in the manager's
    ``_config_methods``, under names from ``get_config_method_name``.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_outputs_idx_list: List[int],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose exported program is modified.
            quantized_outputs_idx_list: Output indices to quantize.
            method_name: Prefix for the recorded config-method names.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager
        self.quantized_outputs_idx_list = quantized_outputs_idx_list
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule):
        # NOTE(review): edits are applied to edge_program_manager.exported_program(),
        # not to graph_module directly — presumably the same graph; confirm.
        for i in self.quantized_outputs_idx_list:
            dequant_args = quantize_output(
                self.edge_program_manager.exported_program(), i
            )
            _record_quant_params(
                self.edge_program_manager,
                self.param_prefix_name,
                "output",
                i,
                dequant_args,
            )

        return PassResult(graph_module, True)