import copy
import operator
from typing import Any, Callable, Optional, Tuple

import torch
from torch.ao.quantization import (
    default_weight_fake_quant,
    default_weight_observer,
    FakeQuantizeBase,
    QConfig,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx


# TODO: move all LSTM util functions from fx/utils.py to this file
def _get_lstm_with_individually_observed_parts(
    float_lstm: torch.nn.LSTM,
    example_inputs: Tuple[Any, ...],
    backend_config: Optional[BackendConfig] = None,
    linear_output_obs_ctr: Optional[_PartialWrapper] = None,
    sigmoid_obs_ctr: Optional[_PartialWrapper] = None,
    tanh_obs_ctr: Optional[_PartialWrapper] = None,
    cell_state_obs_ctr: Optional[_PartialWrapper] = None,
    hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
) -> torch.ao.nn.quantizable.LSTM:
    """
    Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
    with specific observers or fake quantizes assigned to the inner ops or submodules.

    In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is
    used as an observed custom module, which is responsible for inserting its own
    observers. By default, all inner ops inherit the parent custom module's QConfig.
    Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM`
    and use this helper function to customize the observer insertion logic.

    This is meant to be used to convert a float module to an observed module in the
    custom module flow. See the usage sketch at the end of this file for an example.

    Args:
        `float_lstm`: The float LSTM module
        `example_inputs`: example inputs for the forward function of the LSTM module
        `backend_config`: BackendConfig to use to observe the LSTM module
        `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b,
            where W is the weight matrix, b is the bias, and x is either the inputs
            or the hidden state from the previous layer (if any)
        `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations
        `tanh_obs_ctr`: observer or fake quantize for tanh activations
        `cell_state_obs_ctr`: observer or fake quantize for the cell state
        `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and
            the output

    Return:
        A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes
        assigned to the inner ops.
    """

    def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig:
        """
        Make a QConfig with fixed qparams observers or fake quantizes.
        """
        if isinstance(obs_ctr(), FakeQuantizeBase):
            weight = default_weight_fake_quant
        else:
            weight = default_weight_observer
        return QConfig(activation=obs_ctr, weight=weight)

    quantizable_lstm = torch.ao.nn.quantizable.LSTM(
        float_lstm.input_size,
        float_lstm.hidden_size,
        float_lstm.num_layers,
        float_lstm.bias,
        float_lstm.batch_first,
        float_lstm.dropout,
        float_lstm.bidirectional,
    )
    quantizable_lstm.qconfig = float_lstm.qconfig

    for idx in range(float_lstm.num_layers):
        quantizable_lstm.layers[
            idx
        ] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
            float_lstm, idx, float_lstm.qconfig, batch_first=False
        )

    # Build QConfigMapping for the LSTM cell
    # Note: FloatFunctional qconfigs will be configured separately below
    cell_qm = QConfigMapping().set_global(float_lstm.qconfig)  # type: ignore[arg-type]
    if sigmoid_obs_ctr is not None:
        cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr))
        cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr))
        cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr))
    if tanh_obs_ctr is not None:
        cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr))

    # Insert observers into each LSTM cell
    # TODO: maybe make this work for layer_bw as well
    for layer in quantizable_lstm.layers:
        cell = layer.layer_fw.cell
        cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
        # HACK: Manually replace the activation_post_process following these ops.
        # This is needed for FloatFunctional ops because there is currently no way
        # to configure these ops in FX graph mode quantization. This is because
        # the FloatFunctional modules simply disappear from the graph after tracing.
        # In the future, we should rewrite quantizable LSTM without FloatFunctionals.
        op_index_to_activation_post_process_ctr = {
            (torch.add, 0): linear_output_obs_ctr,  # gates.add
            (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
            (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
            (torch.add, 1): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
            (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
        }
        add_count = 0
        mul_count = 0
        for node in cell.graph.nodes:
            op_index: Optional[Tuple[Callable, int]] = None  # e.g. (torch.add, 1)
            if node.target == torch.add:
                op_index = (torch.add, add_count)
                add_count += 1
            elif node.target == torch.mul:
                op_index = (torch.mul, mul_count)
                mul_count += 1
            else:
                # Neither torch.add nor torch.mul
                continue
            if op_index not in op_index_to_activation_post_process_ctr:
                continue
            assert len(node.users) == 1
            activation_post_process_name = next(iter(node.users.keys())).name
            activation_post_process_ctr = op_index_to_activation_post_process_ctr[
                op_index
            ]
            if activation_post_process_ctr is not None:
                setattr(
                    cell, activation_post_process_name, activation_post_process_ctr()
                )
        layer.layer_fw.cell = cell
    return quantizable_lstm


def _get_reference_quantized_lstm_module(
    observed_lstm: torch.ao.nn.quantizable.LSTM,
    backend_config: Optional[BackendConfig] = None,
) -> torch.ao.nn.quantized.LSTM:
    """
    Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM`
    with observers or fake quantizes inserted through `prepare_fx`, e.g. from
    `_get_lstm_with_individually_observed_parts`.

    This is meant to be used to convert an observed module to a quantized module in the
    custom module flow. See the usage sketch at the end of this file for an example.

    Args:
        `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx`
        `backend_config`: BackendConfig to use to produce the reference quantized model

    Return:
        A reference `torch.ao.nn.quantized.LSTM` module.
    """
    quantized_lstm = torch.ao.nn.quantized.LSTM(
        observed_lstm.input_size,
        observed_lstm.hidden_size,
        observed_lstm.num_layers,
        observed_lstm.bias,
        observed_lstm.batch_first,
        observed_lstm.dropout,
        observed_lstm.bidirectional,
    )

    for i, layer in enumerate(quantized_lstm.layers):
        cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell)  # type: ignore[union-attr]
        cell = convert_to_reference_fx(cell, backend_config=backend_config)  # type: ignore[arg-type]
        assert isinstance(cell, torch.fx.GraphModule)
        # HACK: Manually remove input quantize nodes and output dequantize nodes,
        # since custom modules expect quint8 inputs and outputs for now. Note that
        # this functionality is supposedly handled through PrepareCustomConfig's
        # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that
        # API doesn't currently handle tuple inputs and outputs, so we have to do
        # this manually for now. In the future we should (1) relax the restriction
        # on custom module input/output dtypes, and (2) expand support for complex
        # input/output structures.
        for node in cell.graph.nodes:
            if node.target == torch.quantize_per_tensor:
                arg = node.args[0]
                # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
                if arg.target == "x" or (
                    arg.target == operator.getitem and arg.args[0].target == "hidden"
                ):
                    with cell.graph.inserting_before(node):
                        node.replace_all_uses_with(arg)
                        cell.graph.erase_node(node)
            if node.target == "output":
                # Remove all dequantize nodes in the output tuple
                for arg in node.args[0]:
                    with cell.graph.inserting_before(node):
                        node.replace_input_with(arg, arg.args[0])
        cell.graph.eliminate_dead_code()
        cell.recompile()
        layer.layer_fw.cell = cell
    return quantized_lstm
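# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of this module's API and not
# executed on import). It shows how the two helpers above are meant to be
# wired into the FX custom module flow via PrepareCustomConfig and
# ConvertCustomConfig. The names `UserObservedLSTM`, `UserQuantizedLSTM`, and
# `MyModel`, the tensor shapes, and the specific fixed qparams below are all
# hypothetical values made up for this example.
def _example_lstm_custom_module_flow() -> torch.fx.GraphModule:
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.fx.custom_config import (
        ConvertCustomConfig,
        PrepareCustomConfig,
    )
    from torch.ao.quantization.observer import FixedQParamsObserver

    class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
        """
        Observed LSTM that assigns user-chosen fixed qparams to its inner ops
        instead of inheriting the parent custom module's QConfig everywhere.
        """

        @classmethod
        def from_float(cls, float_lstm):
            assert isinstance(float_lstm, cls._FLOAT_MODULE)
            # Illustrative fixed qparams: 16-bit-style ranges stored in qint32
            # for the internal signals, quint8 for the hidden state / output
            linear_output_obs_ctr = FixedQParamsObserver.with_args(
                scale=2**-11, zero_point=2**15, dtype=torch.qint32
            )
            sigmoid_obs_ctr = FixedQParamsObserver.with_args(
                scale=2**-16, zero_point=0, dtype=torch.qint32
            )
            tanh_obs_ctr = FixedQParamsObserver.with_args(
                scale=2**-15, zero_point=2**15, dtype=torch.qint32
            )
            cell_state_obs_ctr = FixedQParamsObserver.with_args(
                scale=2**-11, zero_point=2**15, dtype=torch.qint32
            )
            hidden_state_obs_ctr = FixedQParamsObserver.with_args(
                scale=2**-7, zero_point=2**7, dtype=torch.quint8
            )
            example_inputs = (
                torch.rand(5, 3, 50),
                (torch.rand(1, 3, 50), torch.rand(1, 3, 50)),
            )
            return _get_lstm_with_individually_observed_parts(
                float_lstm=float_lstm,
                example_inputs=example_inputs,
                linear_output_obs_ctr=linear_output_obs_ctr,
                sigmoid_obs_ctr=sigmoid_obs_ctr,
                tanh_obs_ctr=tanh_obs_ctr,
                cell_state_obs_ctr=cell_state_obs_ctr,
                hidden_state_obs_ctr=hidden_state_obs_ctr,
            )

    class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
        """Reference quantized LSTM produced from `UserObservedLSTM`."""

        @classmethod
        def from_observed(cls, observed_lstm):
            assert isinstance(observed_lstm, cls._FLOAT_MODULE)
            return _get_reference_quantized_lstm_module(observed_lstm)

    class MyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.lstm = torch.nn.LSTM(50, 50, 1)

        def forward(self, inputs, h0, c0):
            return self.lstm(inputs, (h0, c0))

    model = MyModel()
    example_inputs = (
        torch.rand(5, 3, 50),
        torch.rand(1, 3, 50),
        torch.rand(1, 3, 50),
    )
    # Register the user classes as the observed/quantized custom module classes
    prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
        torch.nn.LSTM, UserObservedLSTM
    )
    convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
        UserObservedLSTM, UserQuantizedLSTM
    )
    prepared = prepare_fx(
        model,
        get_default_qconfig_mapping(),
        example_inputs,
        prepare_custom_config=prepare_custom_config,
    )
    prepared(*example_inputs)  # calibrate
    return convert_to_reference_fx(
        prepared, convert_custom_config=convert_custom_config
    )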