# mypy: allow-untyped-defs
import copy
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch
from torch.ao.quantization.experimental.adaround_fake_quantize import (
    AdaroundFakeQuantizer,
)
from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss
from torch.ao.quantization.observer import MinMaxObserver
from torch.nn import functional as F
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, TensorDataset


class AdaptiveRoundingOptimizer:
    """Optimizes the rounding of Conv1d/Linear weights (adaptive rounding, a.k.a.
    AdaRound) so that the fake-quantized model reproduces the float model's
    per-layer outputs as closely as possible."""

    def __init__(
        self,
        model: Union[torch.nn.Module, torch.nn.DataParallel],
        callback: Callable[
            [
                Union[torch.nn.Module, torch.nn.DataParallel],
                Any,
                Optional[torch.nn.Module],
            ],
            None,
        ],
        forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable],
        data: Any,
        observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver,
        max_iter=10000,
        dtype: torch.dtype = torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme: torch.qscheme = torch.per_tensor_symmetric,
        batch_size: int = 256,
        feed_forward_wrapper: Optional[torch.nn.Module] = None,
    ):
        if torch.cuda.is_available():
            self.model = model.cuda()
            if torch.cuda.device_count() > 1:
                self.model = torch.nn.DataParallel(model)
        else:
            self.model = model
        self.q_model = copy.deepcopy(self.model)
        self.device = torch.device("cuda") if torch.cuda.is_available() else None
        self.callback = callback
        self.forward_hook_wrapper = forward_hook_wrapper
        # TODO: rather than taking data as a list, accept an *iterator* instead
        self.data = data
        self.batch_size = min(batch_size, len(data))
        self.max_iter = max_iter
        self.adaptive_round_loss_fn = AdaptiveRoundingLoss(
            max_iter=self.max_iter, warm_start=0.2
        )
        self.dtype = dtype
        self.observer = observer
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.feed_forward_wrapper = feed_forward_wrapper

    def run_adaround(self) -> torch.nn.Module:
        layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = []
        for (name, module), q_module in zip(
            self.model.named_modules(), self.q_model.modules()
        ):
            if isinstance(module, torch.nn.ReLU):
                # Disable all in-place operations
                module.inplace = False
            if isinstance(q_module, torch.nn.ReLU):
                # Disable all in-place operations
                q_module.inplace = False
            if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
                # Knowing the activation ahead of time would help an asymmetric
                # formulation, but that is hard in eager mode (unlike a graph module).
                layer_list.append((name, module, q_module))
        print(f"Total number of layers: {len(layer_list)}")  # noqa: G004

        for name, module, q_module in layer_list:
            print(
                f"Kick-starting adaptive rounding on {name} module {module}"  # noqa: G004
            )
            self.optimize_adaptive_rounding(
                module,
                q_module,
                None,
            )

        return (
            self.q_model.module
            if isinstance(self.q_model, DataParallel)
            else self.q_model
        )

    def get_data_inp_out(
        self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        fp_out: List[torch.Tensor] = []
        q_input: List[torch.Tensor] = []
        fp_input: List[torch.Tensor] = []
        fp32_fetcher: List[torch.Tensor] = []
        quant_fetcher: List[torch.Tensor] = []
        handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher))
        handler2 = q_module.register_forward_hook(
            self.forward_hook_wrapper(quant_fetcher)
        )
        if torch.cuda.is_available():
            # Somehow, we need to keep moving the models back to the GPU;
            # otherwise they mysteriously end up on the CPU
            self.model = self.model.cuda()
            self.q_model = self.q_model.cuda()
        for data_ in data:
            with torch.no_grad():
                self.callback(self.model, data_, self.feed_forward_wrapper)
                self.callback(self.q_model, data_, self.feed_forward_wrapper)
                fp32_output = fp32_fetcher[1]
                quant_input = quant_fetcher[0]
                fp_out.append(fp32_output)
                q_input.append(quant_input)
                fp_input.append(fp32_fetcher[0])
        handler1.remove()
        handler2.remove()
        return q_input, fp_out, fp_input

    @torch.no_grad()
    def feed_forward(self, x, weight, module):
        if isinstance(module, torch.nn.Conv1d):
            out = torch.nn.functional.conv1d(
                x,
                weight,
                bias=module.bias,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
            )
        elif isinstance(module, torch.nn.Linear):
            out = torch.nn.functional.linear(
                x,
                weight,
                bias=module.bias,
            )
        else:
            raise NotImplementedError
        return out

    def _compute_and_display_local_losses(
        self,
        ada_quantizer: AdaroundFakeQuantizer,
        q_module: torch.nn.Module,
        q_inp: torch.Tensor,
        fp_out: torch.Tensor,
    ):
        with torch.no_grad():
            ada_quantizer.use_soft_rounding = False
            q_w_hard_round = ada_quantizer(q_module.weight)
            out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module)
            ada_quantizer.use_soft_rounding = True
            q_w_soft_round = ada_quantizer(q_module.weight)
            out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module)
            soft_quant_loss = F.mse_loss(out_soft_quant, fp_out)
            hard_quant_loss = F.mse_loss(out_hard_quant, fp_out)
            print(
                f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}"  # noqa: G004
            )

    def optimize_adaptive_rounding(
        self,
        module: torch.nn.Module,
        q_module: torch.nn.Module,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        ada_quantizer = AdaroundFakeQuantizer(
            dtype=self.dtype,
            observer=self.observer,
            qscheme=self.qscheme,
            quant_min=self.quant_min,
            quant_max=self.quant_max,
            reduce_range=False,
        )
        ada_quantizer.enable_observer()
        ada_quantizer(q_module.weight)
        ada_quantizer.disable_observer()
        ada_quantizer.enable_fake_quant()
        optimizer = torch.optim.Adam([ada_quantizer.V])
        inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data)

        print("==================== Before adaround ====================")
        assert (
            torch.abs(out[0] - module(fp_in[0])).sum().item() == 0
        ), "In-place activation detected; please do not use in-place activations"
        # Stack the captured per-batch tensors into single tensors
        inp_tensor = torch.vstack(inp)
        out_tensor = torch.vstack(out)
        dataset = TensorDataset(inp_tensor, out_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0])
        global_idx = 0
        one_iter = len(out) // self.batch_size
        for iteration in range(self.max_iter // one_iter):
            reconstruction_loss = regularization_loss = torch.tensor(0)
            for q_inp, fp_out in dataloader:
                optimizer.zero_grad()
                q_weight = ada_quantizer(q_module.weight)
                if isinstance(module, torch.nn.Conv1d):
                    q_out = torch.nn.functional.conv1d(
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                        stride=q_module.stride,
                        padding=q_module.padding,
                        dilation=q_module.dilation,
                        groups=q_module.groups,
                    )
                elif isinstance(module, torch.nn.Linear):
                    q_out = torch.nn.functional.linear(
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                    )
                else:
                    raise NotImplementedError
                regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn(
                    fp_out,
                    q_out,
                    ada_quantizer.V,
                    curr_iter=global_idx,
                )
                loss = regularization_loss + reconstruction_loss
                loss.backward()
                optimizer.step()
                global_idx += 1
                if global_idx >= self.max_iter:
                    break
            if global_idx >= self.max_iter:
                break
            if iteration % 30 == 0:
                print(
                    f"global iter {global_idx} regularization_loss {regularization_loss.item()} "  # noqa: G004
                    f"reconstruction_loss {reconstruction_loss.item()}"  # noqa: G004
                )
        print("==================== After adaround ====================")
        self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0])

        ada_quantizer.use_soft_rounding = True
        ada_quantizer.V.requires_grad = False
        ada_quantizer = ada_quantizer.eval()
        q_weight = ada_quantizer(q_module.weight)
        # At the end of optimization, copy the adarounded weight back into the quantized module
        q_module.weight.data.copy_(q_weight)
        # Eager-mode quantization expects the observer to be attached as "weight_fake_quant"
        q_module.weight_fake_quant = ada_quantizer.activation_post_process
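

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original API surface). The toy
# model, `example_forward_hook_wrapper`, and `example_callback` below are
# assumptions chosen to match the interfaces used above: the callback is
# invoked as callback(model, batch, feed_forward_wrapper), and the hook
# wrapper must return a forward hook that leaves the module's input at
# index 0 and its output at index 1 of the supplied fetcher list. The sketch
# assumes CPU or a single GPU (forward hooks do not fire on DataParallel
# replicas).
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def example_forward_hook_wrapper(fetcher: List[torch.Tensor]) -> Callable:
        # Record the detached input/output of the latest forward pass so that
        # get_data_inp_out can read fetcher[0] (input) and fetcher[1] (output).
        def hook(module, inputs, output):
            fetcher.clear()
            fetcher.append(inputs[0].detach())
            fetcher.append(output.detach())

        return hook

    def example_callback(model, batch, feed_forward_wrapper=None) -> None:
        # A plain forward pass is enough to trigger the registered hooks.
        with torch.no_grad():
            model(batch)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    toy_model = torch.nn.Sequential(
        torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
    )
    calibration_data = [torch.randn(4, 16, device=device) for _ in range(8)]

    ada_optimizer = AdaptiveRoundingOptimizer(
        toy_model,
        example_callback,
        example_forward_hook_wrapper,
        calibration_data,
        max_iter=100,
        batch_size=4,
    )
    adarounded_model = ada_optimizer.run_adaround()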