1# mypy: allow-untyped-defs 2import operator 3import warnings 4from collections import namedtuple 5from typing import Any, Dict, List, Optional, Tuple 6 7import torch 8import torch.ao.nn.intrinsic as nni 9import torch.nn as nn 10import torch.nn.functional as F 11from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr 12from torch.ao.quantization.observer import ( 13 _with_args, 14 ObserverBase, 15 PerChannelMinMaxObserver, 16) 17from torch.ao.quantization.utils import _parent_name, check_min_max_valid 18from torch.fx import GraphModule 19from torch.fx.graph import Node 20 21from .utils import ( 22 get_new_attr_name_with_prefix, 23 maybe_get_next_module, 24 node_arg_is_weight, 25) 26 27 28CUSTOM_MODULE_SUPP_LIST: List[Any] = [] 29 30 31def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor: 32 """Reshapes the scale so that we can multiply it to the input by the given axis.""" 33 new_shape = [1] * input.ndim 34 new_shape[axis] = input.size(axis) 35 return scale.view(new_shape) 36 37 38qsheme_mapping_per_tensor_to_per_channel = { 39 torch.per_tensor_affine: torch.per_channel_affine, 40 torch.per_tensor_symmetric: torch.per_channel_symmetric, 41} 42 43 44class _InputEqualizationObserver(nn.Module): 45 r"""Observer for tracking the running min/max values of input columns, and 46 computing the quantization parameters for the overall min/max input values. 47 48 Args: 49 dtype: Quantized data type 50 qscheme: Quantization scheme 51 quant_min: Minimum quantization value. If unspecified, it will 52 follow the 8-bit setup. 53 quant_max: Maximum quantization value. If unspecified, it will 54 follow the 8-bit setup. 55 56 The running minimum/maximum :math:`x_\text{min/max}` are computed in the 57 same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`, 58 with the difference that the running min/max values are stored per column. 59 This observer is intended to be used along with a WeightEqualizationObserver 60 to calculate the equalization scale. 61 """ 62 63 def __init__( 64 self, 65 dtype=torch.quint8, 66 qscheme=torch.per_tensor_affine, 67 quant_min=None, 68 quant_max=None, 69 factory_kwargs=None, 70 ) -> None: 71 super().__init__() 72 73 if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}: 74 raise TypeError("Input qscheme must be per-tensor") 75 76 self.dtype = dtype 77 self.qscheme = qscheme 78 79 per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] 80 self.input_obs = PerChannelMinMaxObserver( 81 ch_axis=1, 82 dtype=dtype, 83 qscheme=per_channel_qscheme, 84 quant_min=quant_min, 85 quant_max=quant_max, 86 factory_kwargs=factory_kwargs, 87 ) 88 89 self.equalization_scale = torch.tensor(1) 90 self.equalization_shape: List[int] = [] 91 92 def forward(self, x_orig): 93 if not (x_orig.ndim >= 2 and x_orig.ndim <= 5): 94 raise ValueError( 95 "InputEqualizationObserver only supports Linear and Conv layers" 96 ) 97 98 # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers) 99 self.equalization_shape = [1] * x_orig.ndim 100 self.equalization_shape[1] = x_orig.size(1) 101 102 return self.input_obs(x_orig) 103 104 def get_input_minmax(self): 105 return (self.input_obs.min_val, self.input_obs.max_val) 106 107 def set_equalization_scale(self, equalization_scale): 108 # Reshape the equalization scale along axis=1 so that it can be 109 # multiplied with the input along axis=1 110 if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): 111 return 112 self.equalization_scale = torch.reshape( 113 equalization_scale, self.equalization_shape 114 ) 115 116 def calculate_scaled_minmax(self): 117 r"""Returns the scaled min/max inputs""" 118 if ( 119 self.equalization_scale.nelement() == 1 120 and self.equalization_scale == torch.tensor(1) 121 ): 122 warnings.warn( 123 "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " 124 + "Will not scale the next quantization observer." 125 ) 126 return None, None 127 128 # Calculate qparams for the scaled min/max inputs 129 # Scale the input by the equalization scale located at the same column 130 # index 131 (min_inputs, max_inputs) = self.get_input_minmax() 132 equalization_scale_reshaped = reshape_scale( 133 self.equalization_scale, 0, min_inputs 134 ) 135 min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped)) 136 max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped)) 137 138 return min_input_scaled, max_input_scaled 139 140 with_args = classmethod(_with_args) 141 142 143class _WeightEqualizationObserver(nn.Module): 144 r"""Observer for tracking the running min/max values of weight columns and 145 rows, and computing the quantization parameters for the weight rows. 146 147 Args: 148 dtype: Quantized data type 149 qscheme: Quantization scheme 150 quant_min: Minimum quantization value. If unspecified, it will 151 follow the 8-bit setup. 152 quant_max: Maximum quantization value. If unspecified, it will 153 follow the 8-bit setup. 154 155 This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used 156 to record the running minimum and maximum of columns of incoming weight 157 tensors. This observer is intended to be used along with an 158 InputEqualizationObserver to calculate the equalization scale. 159 160 The running minimum/maximum :math:`w_\text{min/max}` are computed in the 161 same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`. 162 """ 163 164 def __init__( 165 self, 166 dtype=torch.qint8, 167 qscheme=torch.per_tensor_affine, 168 quant_min=None, 169 quant_max=None, 170 factory_kwargs=None, 171 ) -> None: 172 super().__init__() 173 174 self.dtype = dtype 175 self.qscheme = qscheme 176 self.ch_axis = 1 177 178 per_channel_qscheme = qscheme 179 if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: 180 per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] 181 self.weight_col_obs = PerChannelMinMaxObserver( 182 ch_axis=1, 183 dtype=dtype, 184 qscheme=per_channel_qscheme, 185 quant_min=quant_min, 186 quant_max=quant_max, 187 factory_kwargs=factory_kwargs, 188 ) 189 190 self.equalization_scale = torch.tensor(1) 191 192 def forward(self, w_orig): 193 if not (w_orig.ndim >= 2 and w_orig.ndim <= 5): 194 raise ValueError( 195 "InputEqualizationObserver only supports Linear and Conv layers" 196 ) 197 198 return self.weight_col_obs(w_orig) 199 200 def get_weight_col_minmax(self): 201 return (self.weight_col_obs.min_val, self.weight_col_obs.max_val) 202 203 def set_equalization_scale(self, equalization_scale): 204 self.equalization_scale = equalization_scale 205 206 with_args = classmethod(_with_args) 207 208 209def calculate_equalization_scale( 210 input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver 211) -> torch.Tensor: 212 r"""Calculates the equalization scale and sets the equalization_scale value 213 in the observers. 214 215 Args: 216 input_obs: Observer that tracks the ranges for the input columns 217 weight_obs: Observer that tracks the ranges for the weight columns 218 """ 219 220 (min_inputs, max_inputs) = input_obs.get_input_minmax() 221 (min_weights, max_weights) = weight_obs.get_weight_col_minmax() 222 223 if not ( 224 check_min_max_valid(min_inputs, max_inputs) 225 and check_min_max_valid(min_weights, max_weights) 226 ): 227 warnings.warn( 228 "Must run observer before calling calculate_equalization_scale. " 229 + "Returning default equalization scale torch.tensor(1)." 230 ) 231 return torch.tensor(1) 232 233 if not (min_inputs.shape == min_weights.shape): 234 raise ValueError( 235 "Input and Weight must have the same column dimension. " 236 + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead." 237 ) 238 239 equalization_scale = torch.sqrt( 240 (max_weights - min_weights) / (max_inputs - min_inputs) 241 ) 242 # Replace all 'inf', 'nan', 0's with 1s to prevent errors 243 equalization_scale[equalization_scale == 0.0] = 1 244 equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1) 245 return equalization_scale 246 247 248class EqualizationQConfig( 249 namedtuple("EqualizationQConfig", ["input_activation", "weight"]) 250): 251 """ 252 Describes how to quantize a layer or a part of the network specifically for 253 input-weight equalization by providing settings (observer classes) for 254 inputs, outputs, and weights. 255 256 Note that EqualizationQConfig needs to contain observer **classes** (like 257 MinMaxObserver) or a callable that returns instances on invocation, not the 258 concrete observer instances themselves. 259 Quantization function will instantiate observers multiple times for each of 260 the layers. 261 262 Observer classes have usually reasonable default arguments, but they can be 263 overwritten with `with_args` method (that behaves like functools.partial): 264 265 my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8), 266 weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8)) 267 """ 268 269 def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity): 270 if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module): 271 raise ValueError( 272 "EqualizationQConfig received observer instance, please pass observer class instead. " 273 + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" 274 ) 275 self = super().__new__(cls, input_activation, weight) 276 return self 277 278 279input_equalization_observer = _InputEqualizationObserver.with_args( 280 dtype=torch.quint8, qscheme=torch.per_tensor_symmetric 281) 282weight_equalization_observer = _WeightEqualizationObserver.with_args( 283 dtype=torch.qint8, qscheme=torch.per_channel_symmetric 284) 285default_equalization_qconfig = EqualizationQConfig( 286 input_activation=input_equalization_observer, weight=weight_equalization_observer 287) 288 289 290def fused_module_supports_equalization(module) -> bool: 291 """Checks if the fused node supports equalization.""" 292 return type(module) in [ 293 nni.LinearReLU, 294 nni.ConvReLU1d, 295 nni.ConvReLU2d, 296 nni.ConvReLU3d, 297 ] 298 299 300def nn_module_supports_equalization(module) -> bool: 301 """Checks if the torch.nn node supports equalization.""" 302 return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d] 303 304 305def custom_module_supports_equalization(module) -> bool: 306 """Checks if the custom node supports equalization.""" 307 return type(module) in CUSTOM_MODULE_SUPP_LIST 308 309 310def node_supports_equalization(node: Node, modules) -> bool: 311 """Checks if the current node supports equalization 312 Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers 313 """ 314 if node.op == "call_module": 315 return ( 316 nn_module_supports_equalization(modules[str(node.target)]) 317 or fused_module_supports_equalization(modules[str(node.target)]) 318 or custom_module_supports_equalization(modules[str(node.target)]) 319 ) 320 elif node.op == "call_function": 321 return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d] 322 return False 323 324 325def is_equalization_observer(observer: nn.Module) -> bool: 326 return isinstance( 327 observer, (_InputEqualizationObserver, _WeightEqualizationObserver) 328 ) 329 330 331############################################################################### 332# Functions for equalization during convert # 333############################################################################### 334 335 336def get_op_node_and_weight_eq_obs( 337 input_eq_obs_node: Node, model: GraphModule, modules: Dict[str, nn.Module] 338) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]: 339 """Gets the following weight equalization observer. There should always 340 exist a weight equalization observer after an input equalization observer. 341 342 Returns the operation node that follows the input equalization observer node 343 and the weight equalization observer 344 """ 345 346 # Find the op node that comes directly after the input equalization observer 347 op_node = None 348 for user in input_eq_obs_node.users.keys(): 349 if node_supports_equalization(user, modules): 350 op_node = user 351 break 352 353 assert op_node is not None 354 if op_node.op == "call_module": 355 # If the op_node is a nn.Linear layer, then it must have a 356 # WeightEqualizationObserver configuration 357 maybe_equalization_node_name_to_config = _get_observed_graph_module_attr( 358 model, "equalization_node_name_to_qconfig" 359 ) 360 assert maybe_equalization_node_name_to_config is not None 361 equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment] 362 assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None 363 weight_eq_obs = equalization_node_name_to_qconfig.get( 364 op_node.name, None 365 ).weight() 366 367 assert isinstance(weight_eq_obs, _WeightEqualizationObserver) 368 return op_node, weight_eq_obs 369 370 elif op_node.op == "call_function": 371 weight_node = maybe_get_weight_eq_obs_node(op_node, modules) 372 if weight_node is not None: 373 weight_eq_obs = modules[str(weight_node.target)] 374 assert isinstance(weight_eq_obs, _WeightEqualizationObserver) 375 return op_node, weight_eq_obs 376 377 return None, None 378 379 380def maybe_get_weight_eq_obs_node( 381 op_node: Node, modules: Dict[str, nn.Module] 382) -> Optional[Node]: 383 """Gets the weight equalization observer node if it exists.""" 384 assert op_node.op == "call_function" 385 for node_arg in op_node.args: 386 if node_arg_is_weight(op_node, node_arg): 387 assert ( 388 isinstance(node_arg, Node) 389 and node_arg.op == "call_module" 390 and isinstance( 391 modules[str(node_arg.target)], _WeightEqualizationObserver 392 ) 393 ) 394 return node_arg 395 return None 396 397 398def maybe_get_next_input_eq_obs( 399 node: Node, modules: Dict[str, nn.Module] 400) -> Optional[_InputEqualizationObserver]: 401 """Gets the following input equalization observer if it exists. 402 403 For example, in the case of connecting linear layers: 404 x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 405 If the node being passed in is the linear1 node, then we want to return eq_obs2, 406 the following equalization observer for linear2. 407 408 However, if there are no connecting layers: 409 x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add 410 Then we want to return None. 411 412 In the case of an unfused linear-relu layer with a connecting linear layer: 413 linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 414 Since it is unfused, we want to skip over the relu layer and return eq_obs2, 415 the following equalization observer for linear2. 416 """ 417 418 assert node_supports_equalization(node, modules) 419 420 # Locate the following nn.ReLU or F.relu node if it exists 421 maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU) 422 if maybe_relu_node is None: 423 maybe_relu_node = maybe_get_next_module( 424 node, modules, target_functional_type=F.relu 425 ) 426 427 # Locate the following output observer if it exists. 428 # We will skip the relu node if it exists. 429 maybe_obs_node = ( 430 maybe_get_next_module(node, modules, ObserverBase) 431 if maybe_relu_node is None 432 else maybe_get_next_module(maybe_relu_node, modules, ObserverBase) 433 ) 434 if maybe_obs_node is None: 435 return None 436 437 maybe_eq_obs_node = maybe_get_next_module( 438 maybe_obs_node, modules, _InputEqualizationObserver 439 ) 440 if maybe_eq_obs_node is None: 441 return None 442 443 maybe_eq_obs = modules[str(maybe_eq_obs_node)] 444 assert isinstance(maybe_eq_obs, _InputEqualizationObserver) 445 return maybe_eq_obs 446 447 448def maybe_get_next_equalization_scale( 449 node: Node, modules: Dict[str, nn.Module] 450) -> Optional[torch.Tensor]: 451 """If the next next node is an InputEqualizationObserver then we want to 452 return its equalization scale, else we return 1 453 454 This is used in the case where there are two connecting linear layers: 455 linear1 -> LinearOutObs -> InputEqObs -> linear2 456 In this case, the node given is linear1 and we want to locate the InputEqObs. 457 """ 458 next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules) 459 if next_inp_eq_obs: 460 if ( 461 next_inp_eq_obs.equalization_scale.nelement() == 1 462 and next_inp_eq_obs.equalization_scale == torch.tensor(1) 463 ): 464 return None 465 return next_inp_eq_obs.equalization_scale 466 return None 467 468 469def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None: 470 """Scales the following input quantization observer's min/max values by 471 updating the values with the scaled min/max values calculated by the input 472 equalization observer 473 """ 474 input_eq_obs = modules[str(node.target)] 475 assert isinstance(input_eq_obs, _InputEqualizationObserver) 476 477 input_quant_obs_node = node.args[0] 478 assert isinstance(input_quant_obs_node, Node) 479 480 input_quant_obs = modules[str(input_quant_obs_node.target)] 481 if not isinstance(input_quant_obs, ObserverBase): 482 return 483 484 min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax() 485 if min_input_scaled is None and max_input_scaled is None: 486 return 487 input_quant_obs.min_val = min_input_scaled 488 input_quant_obs.max_val = max_input_scaled 489 490 491def scale_weight_node( 492 node: Node, 493 modules: Dict[str, nn.Module], 494 equalization_scale: torch.Tensor, 495 next_equalization_scale: Optional[torch.Tensor], 496) -> None: 497 """Scale the weights for input-weight equalization by multiplying the 498 weight by 1/equalization_scale and next_equalization_scale 499 500 Args: 501 node: Current node whose weights we want to scale 502 equalization_scale: Current node's calculated equalization scale 503 next_equalization_scale: Next node's calculated equalization scale if 504 the following node needs to be equalized, 1 otherwise 505 """ 506 if equalization_scale is None: 507 return 508 509 if fused_module_supports_equalization(modules[str(node.target)]): 510 op_module = modules[str(node.target)][0] # type: ignore[index] 511 else: 512 op_module = modules[str(node.target)] 513 assert nn_module_supports_equalization( 514 op_module 515 ) or custom_module_supports_equalization(op_module) 516 517 # Scale the weights for input-weight equalization 518 # If the following layer needs to be equalized then we will multiply its scale 519 weight = op_module.weight 520 assert isinstance(weight, torch.Tensor) 521 522 # Scale the weights by the reciprocal of the equalization scale 523 # Reshape the equalization scale so that we can multiply it to the weight along axis=1 524 equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) 525 scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) 526 527 if next_equalization_scale is None: 528 op_module.weight = nn.Parameter(scaled_weight) 529 return 530 531 # Multiply the weights row wise by the next equalization scale 532 # Reshape the equalization scale so that we can multiply it to the weight along axis=0 533 next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight) 534 scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) 535 536 op_module.weight = nn.Parameter(scaled_weight) 537 538 # Multiply the bias element wise by the next equalization scale 539 bias = op_module.bias 540 if bias is None: 541 return 542 assert isinstance(bias, torch.Tensor) 543 544 # Reshape the equalization scale so that we can multiply it element-wise to the bias 545 next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) 546 scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) 547 op_module.bias = nn.Parameter(scaled_bias) 548 549 550def scale_weight_functional( 551 op_node: Node, 552 model: GraphModule, 553 modules: Dict[str, nn.Module], 554 equalization_scale: torch.Tensor, 555 next_equalization_scale: Optional[torch.Tensor], 556) -> None: 557 """Scales the weight value for functional layers""" 558 if equalization_scale is None: 559 return 560 561 # From the given op_node, the path looks like: 562 # get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node 563 # So we want to trace back from the op_node to get the equalization observer 564 # node, then the quantization observer node, and then finally the weight 565 # node which contains the weight values. 566 567 # Get the equalization observer node 568 weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) 569 if weight_eq_obs_node is None: 570 return 571 572 # Get the quantization observer node 573 weight_quant_obs_node = weight_eq_obs_node.args[0] 574 if weight_quant_obs_node is None: 575 return 576 assert isinstance(weight_quant_obs_node, Node) and isinstance( 577 modules[str(weight_quant_obs_node.target)], ObserverBase 578 ) 579 580 # Get the get_attr(weight) node 581 weight_node = weight_quant_obs_node.args[0] 582 if weight_node is None: 583 return 584 assert isinstance(weight_node, Node) and weight_node.op == "get_attr" 585 586 weight_parent_name, weight_name = _parent_name(weight_node.target) 587 weight = getattr(modules[weight_parent_name], weight_name) 588 589 # Scale the weights for input-weight equalization 590 # If the following layer needs to be equalized then we will multiply its scale 591 # Reshape the equalization scale so that we can multiply it to the weight along axis=1 592 equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) 593 scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) 594 595 if next_equalization_scale is None: 596 setattr(modules[weight_parent_name], weight_name, scaled_weight) 597 return 598 599 # Multiply the weights row wise by the next equalization scale 600 # Reshape the equalization scale so that we can multiply it to the weight along axis=1 601 next_equalization_scale_reshaped = reshape_scale( 602 next_equalization_scale, 0, scaled_weight 603 ) 604 scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) 605 606 setattr(modules[weight_parent_name], weight_name, scaled_weight) 607 assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight) 608 609 # Multiply the bias element wise by the next equalization scale 610 bias_node = None 611 for node in op_node.args: 612 # Find the node containing the weight values 613 if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name: 614 bias_node = node 615 break 616 if bias_node is None: 617 return 618 619 bias_parent_name, bias_name = _parent_name(bias_node.target) 620 bias = getattr(modules[bias_parent_name], bias_name) 621 622 # Reshape the equalization scale so that we can multiply it element-wise to the bias 623 next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) 624 scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) 625 setattr(modules[bias_parent_name], bias_name, scaled_bias) 626 627 628def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None: 629 """Given the operation node, we want find the corresponding quantization 630 observer and reset its min/max values 631 """ 632 weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) 633 if weight_eq_obs_node is None: 634 return 635 636 weight_quant_obs_node = weight_eq_obs_node.args[0] 637 if weight_quant_obs_node is None: 638 return 639 assert isinstance(weight_quant_obs_node, Node) 640 641 weight_quant_obs = modules[str(weight_quant_obs_node.target)] 642 assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase) 643 weight_quant_obs.reset_min_max_vals() # type: ignore[operator] 644 645 646def remove_node(model: GraphModule, node: Node, prev_node: Node): 647 """Removes the given node from the model by replacing all of its users with 648 the given previous node 649 """ 650 # For all of the current node's users, replace the current node with 651 # the input quantization observer node 652 orig_users = list(node.users.keys()) 653 for user_node in orig_users: 654 user_node.replace_input_with(node, prev_node) 655 656 # Erase the InputEqualizationObserver node 657 model.graph.erase_node(node) 658 659 660def update_obs_for_equalization( 661 model: GraphModule, modules: Dict[str, nn.Module] 662) -> Dict[str, _WeightEqualizationObserver]: 663 """Update all of the observer's equalization scale. For each 664 InputEqualizationObserver, we will find the location of the next 665 WeightEqualizationObserver, create it, and calculate the equalization scale 666 based on the two observers. 667 668 We will then return a dictionary mapping operation node names to 669 the corresponding WeightEqualizationObservers for that operation. 670 """ 671 weight_eq_obs_dict = {} 672 for node in model.graph.nodes: 673 if node.op == "call_module" and isinstance( 674 modules[node.target], _InputEqualizationObserver 675 ): 676 input_eq_obs = modules[node.target] 677 assert isinstance(input_eq_obs, _InputEqualizationObserver) 678 op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules) 679 680 if op_node is None or weight_eq_obs is None: 681 continue 682 683 if op_node.op == "call_module": 684 # Calibrate the weight equalization observer since it has just 685 # been created 686 if fused_module_supports_equalization(modules[str(op_node.target)]): 687 module = modules[str(op_node.target)][0] # type: ignore[index] 688 assert nn_module_supports_equalization(module) 689 weight_eq_obs(module.weight) 690 else: 691 weight_eq_obs(modules[str(op_node.target)].weight) 692 693 # Calculate and set the equalization scale values 694 equalization_scale = calculate_equalization_scale( 695 input_eq_obs, weight_eq_obs 696 ) 697 input_eq_obs.set_equalization_scale(equalization_scale) 698 weight_eq_obs.set_equalization_scale(equalization_scale) 699 700 weight_eq_obs_dict[op_node.name] = weight_eq_obs 701 702 return weight_eq_obs_dict 703 704 705def convert_eq_obs( 706 model: GraphModule, 707 modules: Dict[str, nn.Module], 708 weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver], 709) -> None: 710 """Converts the equalization operations and updates the other nodes in the 711 following way: 712 - Removes the input equalization observers and inserts a mul operator 713 along with an equalization scale node wherever applicable (we do not 714 want to insert a mul operator between connecting linear layers). 715 - Updates the input quantization observers with the scaled input min/max 716 values. 717 - Scales the weights by the current and next equalization scales. 718 - Removes the weight equalization observer node if it exists. 719 720 Before (after prepare): 721 weight values 722 | 723 WeightQuantObs 724 | 725 WeightEqObs 726 | 727 x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs 728 729 After this function: 730 scaled weight values 731 | 732 equalization scale WeightQuantObs 733 | | 734 x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs 735 736 After convert: 737 equalization scale scaled weight values 738 | | 739 x -> mul -> quantize_per_tensor -> quantized::linear 740 741 Note that although the equalization observer appeared after the quantization 742 observer after prepare_fx, the mul node appears before the quantization node 743 after convert_fx. This is because placing the equalization observer after 744 the quantization observer in prepare_fx would allow us to keep the invariant 745 that the graph before the current node inserts its observers is not 746 modified. 747 748 Having the equalization observer before the quantization observer would also 749 cause some inconsistences between the ordering of the quantization and 750 equalization observers. 751 For example, a single linear layer would look like: 752 x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1 753 But between two connected linear layers, it would look like: 754 linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2 755 """ 756 for node in model.graph.nodes: 757 if node.op == "call_module" and isinstance( 758 modules[node.target], _InputEqualizationObserver 759 ): 760 inp_quant_obs_node = node.args[0] 761 prev_node = inp_quant_obs_node.args[0] 762 763 # If the previous node is a layer that needs to be equalized, then 764 # we will remove the current node because we do not need to add any 765 # equalization nodes between two layers that need to be equalized 766 767 # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2 768 # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2 769 if ( 770 node_supports_equalization(prev_node, modules) 771 or "relu" in prev_node.name 772 ): 773 remove_node(model, node, inp_quant_obs_node) 774 continue 775 776 # Update the following input quantization observer's min/max values 777 scale_input_observer(node, modules) 778 779 # Remove the InputEqualization node and add a mul operator before 780 # the quantization observer node that appears before the equalization node 781 # Before: x -> input_quant_obs -> input_eq_obs -> linear 782 # After: x -> mul -> input_quant_obs -> linear 783 784 # Create a node containing the equalization scale 785 with model.graph.inserting_before(inp_quant_obs_node): 786 get_new_eq_scale_name = get_new_attr_name_with_prefix( 787 prev_node.name + "_equalization_scale" 788 ) 789 name = get_new_eq_scale_name(modules) 790 setattr(model, name, modules[node.target].equalization_scale) 791 eq_scale_node = model.graph.create_node("get_attr", name) 792 793 # Create a node multiplying the input with the equalization scale 794 with model.graph.inserting_after(eq_scale_node): 795 inputs = (prev_node, eq_scale_node) 796 mul_node = model.graph.create_node("call_function", torch.mul, inputs) 797 798 # Set the mul nod to be the input_quant_obs_node's input instead of 799 # the previous node 800 inp_quant_obs_node.replace_input_with(prev_node, mul_node) 801 remove_node(model, node, inp_quant_obs_node) 802 803 elif weight_eq_obs_dict.get(node.name, None) is not None: 804 weight_eq_obs = weight_eq_obs_dict.get(node.name) 805 assert isinstance(weight_eq_obs, _WeightEqualizationObserver) 806 equalization_scale = weight_eq_obs.equalization_scale 807 808 if ( 809 equalization_scale.nelement() == 1 810 and equalization_scale == torch.tensor(1) 811 ): 812 equalization_scale = None # type: ignore[assignment] 813 maybe_next_equalization_scale = maybe_get_next_equalization_scale( 814 node, modules 815 ) 816 817 # Scale the weight nodes 818 if node.op == "call_module": 819 scale_weight_node( 820 node, modules, equalization_scale, maybe_next_equalization_scale 821 ) 822 elif node.op == "call_function": 823 scale_weight_functional( 824 node, 825 model, 826 modules, 827 equalization_scale, 828 maybe_next_equalization_scale, 829 ) 830 831 weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) 832 if weight_eq_obs_node is None: 833 return 834 assert isinstance( 835 modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver 836 ) 837 838 # Clear the quantization observer's min/max values so that they 839 # can get updated later based on the new scale values 840 clear_weight_quant_obs_node(node, modules) 841 842 # Erase the weight equalization observer node 843 prev_node = weight_eq_obs_node.args[0] 844 remove_node(model, weight_eq_obs_node, prev_node) 845 else: 846 raise ValueError( 847 "Expected operation node to be 'call_module' or 'call_function" 848 + f"Instead got node {node.name} as '{node.op}'." 849 ) 850 851 852def _convert_equalization_ref(model: GraphModule): 853 """Reference function which applies changes needed for equalization, but 854 does not quantize the nodes 855 """ 856 modules = dict(model.named_modules(remove_duplicate=False)) 857 858 # Calculate the equalization scale, update the observers with the scaled 859 # inputs, and scale the weight 860 weight_eq_obs_dict = update_obs_for_equalization(model, modules) 861 convert_eq_obs(model, modules, weight_eq_obs_dict) 862 863 return GraphModule(model, model.graph) 864 865 866############################################################################### 867# Functions for running the equalized model on the Numeric Suite # 868############################################################################### 869 870 871def get_layer_sqnr_dict( 872 model_a: nn.Module, model_b: nn.Module, x: torch.Tensor 873) -> Dict[str, float]: 874 """Runs the Numeric Suite on model_a and model_b and returns a dictionary 875 containing the SQNR between layers in model_a and model_b. 876 877 Note: In order to support equalized models, this function has a hacky fix in 878 which we do not match any torch.mul operators. This is because equalized 879 models contain extra mul operators to scale the input by the equalization 880 scale, but this edge case has not been resolved yet within the numeric suite code. 881 882 Args: 883 model_a: A float model 884 model_b: A quantized model 885 x: Inputs to use during calibration 886 """ 887 import torch.ao.ns._numeric_suite_fx as ns 888 from torch.ao.ns.fx.mappings import get_unmatchable_types_map 889 890 unmatchable_types_map = get_unmatchable_types_map() 891 unmatchable_types_map["funs_unmatchable"].add(torch.mul) 892 893 model_a_ns, model_b_ns = ns.add_loggers( 894 "fp32", 895 model_a, 896 "int8", 897 model_b, 898 ns.OutputLogger, 899 unmatchable_types_map=unmatchable_types_map, 900 ) 901 902 model_a_ns(x) 903 model_b_ns(x) 904 905 activation_comparison_dict = ns.extract_logger_info( 906 model_a_ns, model_b_ns, ns.OutputLogger, "int8" 907 ) 908 ns.extend_logger_results_with_comparison( 909 activation_comparison_dict, 910 "fp32", 911 "int8", 912 torch.ao.ns.fx.utils.compute_sqnr, 913 "sqnr", 914 ) 915 916 # Construct a dictionary mapping layer names to the SQNR values 917 layer_sqnr_dict = {} 918 for key in activation_comparison_dict: 919 layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"] 920 sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0] 921 layer_sqnr_dict[layer] = sqnr 922 923 return layer_sqnr_dict 924 925 926def get_equalization_qconfig_dict( 927 layer_sqnr_dict: Dict[str, float], num_layers_to_equalize: int 928) -> Any: 929 """Given the layer to SQNR dictionary, find the layers with the highest 930 quantization errors, and return an equalization_qconfig_dict 931 specifying to only equalize those top layers. 932 933 Args: 934 layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found 935 when comparing an equalized model against a float model) 936 num_layers_to_equalize: Number of layers with the highest quantization 937 errors to equalize 938 """ 939 940 # Sort the layer_sqnr_dictionary values and get the layers with the lowest 941 # SQNR values (aka highest quantization errors) 942 layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1)) 943 layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize] 944 945 # Constructs an equalization_qconfig_dict that specifies to only equalize 946 # the layers with the highest quantization errors 947 module_to_qconfig_list = [ 948 (item[0], default_equalization_qconfig) for item in layers_to_equalize 949 ] 950 equalization_qconfig_dict = {"module_name": module_to_qconfig_list} 951 return equalization_qconfig_dict 952