# mypy: allow-untyped-defs
"""
Utils shared by different modes of quantization (eager/graph)
"""
import functools
import warnings
from collections import OrderedDict
from inspect import getfullargspec, signature
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch.ao.quantization.quant_type import QuantType
from torch.fx import Node
from torch.nn.utils.parametrize import is_parametrized


NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
NodePattern.__module__ = "torch.ao.quantization.utils"

# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# Define separately to prevent circular imports.
# TODO(future PR): improve this.
# make this public once fixed (can't be public as is because setting the module directly
# doesn't work)
QuantizerCls = Any

# Type for fusion patterns, it can be more complicated than the following actually,
# see pattern.md for docs
# TODO: not sure if typing supports recursive data types
Pattern = Union[
    Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
]
Pattern.__module__ = "torch.ao.quantization.utils"


# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
    """A node pattern that matches all nodes, used in defining
    fusion patterns in FX Graph Mode Quantization
    """


module_type_list = {
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    torch.nn.Identity,
    torch.nn.Hardsigmoid,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
}
func_list = {
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.elu,
    torch.nn.functional.hardswish,
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.silu,
    torch.nn.functional.mish,
    torch.nn.functional.dropout,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    torch.nn.functional.relu,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.sigmoid,
    torch.transpose,
    torch.repeat_interleave,
    torch.sigmoid,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.tanh,
    torch.unsqueeze,
    torch.cat,
}
method_list = {
    torch.mean,
    "relu",
    "relu_",
    "contiguous",
    "detach",
    "detach_",
    "hardsigmoid",
    "hardsigmoid_",
    "permute",
    "repeat",
    "repeat_interleave",
    "reshape",
    "resize_",
    "shape",
    "sigmoid",
    "sigmoid_",
    "size",
    "squeeze",
    "squeeze_",
    "tanh",
    "tanh_",
    "transpose",
    "unsqueeze",
    "unsqueeze_",
    "view",
}


# TODO: not used now, remove
def check_node(node, modules):
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    is_call_function = node.op == "call_function" and node.target in func_list
    is_call_method = node.op == "call_method" and node.target in method_list
    is_call_module = (
        node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
    )
    return is_call_function, is_call_method, is_call_module
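

# The sketch below is an illustrative, hypothetical usage example added for
# documentation purposes only; it is never called and is not part of the public
# API. It traces a tiny model with torch.fx and classifies each node with
# ``check_node``.
def _example_check_node() -> None:
    class _M(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x).sigmoid()

    gm = torch.fx.symbolic_trace(_M())
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        # prints (is_call_function, is_call_method, is_call_module) per node
        print(node.op, node.target, check_node(node, modules))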


def get_combined_dict(default_dict, additional_dict):
    """
    Combines two dictionaries.

    This function takes two dictionaries as input and returns a new dictionary
    that contains all the key-value pairs from both input dictionaries.
    If there are any duplicate keys in the `additional_dict`, the values
    from the `additional_dict` will overwrite those in the `default_dict`.

    Args:
        default_dict (dict): The main dictionary that will be used as the base
        additional_dict (dict): The dictionary used to update `default_dict`

    Returns:
        dict: The resulting dictionary

    Example:
        >>> x = dict(a=1, b=1)
        >>> y = dict(b=2, c=3)
        >>> get_combined_dict(x, y)
        {'a': 1, 'b': 2, 'c': 3}
    """
    d = default_dict.copy()
    d.update(additional_dict)
    return d


def is_per_tensor(qscheme):
    return qscheme == torch.per_tensor_affine or qscheme == torch.per_tensor_symmetric


def is_per_channel(qscheme):
    return qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    ]


def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """
    Given an obj and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
    """
    return functools.reduce(getattr, fqn.split("."), obj)


def to_underlying_dtype(qdtype):
    DTYPE_MAPPING = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
        torch.quint4x2: torch.uint8,
        torch.quint2x4: torch.uint8,
        torch.uint8: torch.uint8,
        torch.int8: torch.int8,
        torch.int16: torch.int16,
        torch.int32: torch.int32,
        torch.float8_e5m2: torch.float8_e5m2,
        torch.float8_e4m3fn: torch.float8_e4m3fn,
    }
    assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
    return DTYPE_MAPPING[qdtype]


def get_qparam_dict(observer_or_fake_quant):
    from torch.ao.quantization.observer import PlaceholderObserver

    qscheme = getattr(observer_or_fake_quant, "qscheme", None)
    dtype = observer_or_fake_quant.dtype
    qparams = {"qscheme": qscheme, "dtype": dtype}

    if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver):
        return {"qscheme": None, "dtype": dtype}

    if is_per_tensor(qscheme):
        qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        # change symmetric to affine since we do not have symmetric
        # quantized Tensor
        if qscheme == torch.per_channel_symmetric:
            qscheme = torch.per_channel_affine
        qparams["axis"] = observer_or_fake_quant.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
    # update qscheme, since we don't have symmetric quant qscheme
    # in quantized Tensor
    qparams["qscheme"] = qscheme

    scale, zero_point = observer_or_fake_quant.calculate_qparams()
    qparams["scale"] = scale
    qparams["zero_point"] = zero_point

    if hasattr(observer_or_fake_quant, "quant_min"):
        qparams["quant_min"] = observer_or_fake_quant.quant_min
    if hasattr(observer_or_fake_quant, "quant_max"):
        qparams["quant_max"] = observer_or_fake_quant.quant_max

    return qparams
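

# Illustrative, hypothetical usage sketch added for documentation only (never
# called, not part of the public API): calibrate a MinMaxObserver and dump its
# quantization parameters with ``get_qparam_dict``.
def _example_get_qparam_dict() -> None:
    from torch.ao.quantization.observer import MinMaxObserver

    obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    obs(torch.randn(4, 4))  # run some data through the observer to record min/max
    qparams = get_qparam_dict(obs)
    # expected keys: qscheme, dtype, scale, zero_point, quant_min, quant_max
    print(qparams)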


def get_swapped_custom_module_class(
    custom_module, custom_module_class_mapping, qconfig
):
    """Get the observed/quantized custom module class that we need
    to swap `custom_module` to.

    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for the input custom module instance
    """
    quant_type = get_quant_type(qconfig)
    class_mapping = custom_module_class_mapping.get(quant_type, {})
    assert type(custom_module) in class_mapping, (
        "did not find corresponding observed "
        f"module class for {type(custom_module)} in mapping: {class_mapping}"
    )
    return class_mapping[type(custom_module)]


def activation_dtype(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    return activation.dtype


def weight_dtype(qconfig):
    assert qconfig is not None
    weight = qconfig.weight()
    return weight.dtype


def activation_is_statically_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be statically
    quantized or not; this includes quantizing to quint8, qint8, qint32 and float16
    """
    return activation_dtype(qconfig) in [
        torch.quint8,
        torch.qint8,
        torch.qint32,
        torch.float16,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ] and (not activation_is_dynamically_quantized(qconfig))


def activation_is_dynamically_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    dynamically quantized or not; this includes dynamically quantizing to
    quint8, qint8 and float16
    """
    activation_dtype, _, activation_is_dynamic = get_qconfig_dtypes(qconfig)
    return activation_is_dynamic


def activation_is_int8_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    quantized to int8 or not; this includes quantizing to quint8 and qint8
    """
    return activation_dtype(qconfig) in [
        torch.quint8,
        torch.qint8,
        torch.uint8,
        torch.int8,
    ]


def activation_is_int32_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    quantized to int32 or not
    """
    return activation_dtype(qconfig) in [torch.qint32, torch.int32]


def weight_is_quantized(qconfig):
    """Given a qconfig, decide if the weight needs to be
    quantized or not
    """
    return weight_dtype(qconfig) in [
        torch.quint8,
        torch.qint8,
        torch.float16,
        torch.quint4x2,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ]


def weight_is_statically_quantized(qconfig):
    """Given a qconfig, decide if the weight needs to be statically
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8]


def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """Given a qconfig, returns True if this op is using int8 dynamic
    quantization
    """
    activation_dtype, weight_dtype, activation_is_dynamic = get_qconfig_dtypes(qconfig)
    return (
        activation_dtype in [torch.quint8, torch.uint8]
        and
        # for now, the lines below assume fbgemm or qnnpack
        weight_dtype in [torch.qint8, torch.int8]
        and activation_is_dynamic
    )
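

# Illustrative, hypothetical usage sketch added for documentation only (never
# called, not part of the public API): it shows what the stock eager-mode
# qconfigs imply for activations and weights. The expected values in the
# comments assume the default observer settings shipped with PyTorch.
def _example_qconfig_predicates() -> None:
    from torch.ao.quantization.qconfig import default_dynamic_qconfig, default_qconfig

    print(activation_is_statically_quantized(default_qconfig))  # expected: True
    print(activation_is_dynamically_quantized(default_dynamic_qconfig))  # expected: True
    print(weight_is_quantized(default_qconfig))  # expected: True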


def get_qconfig_dtypes(qconfig):
    r"""Returns the dtype tuple for a qconfig:
    (activation_dtype, weight_dtype, activation_is_dynamic)
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    act_is_dynamic = getattr(activation, "is_dynamic", False)
    return (activation.dtype, weight.dtype, act_is_dynamic)


def get_quant_type(qconfig):
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = [
        torch.quint8,
        torch.qint8,
        torch.quint4x2,
        torch.qint32,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ]
    if weight.dtype in static_dtypes:
        if hasattr(activation, "is_dynamic") and activation.is_dynamic:
            return QuantType.DYNAMIC
        elif activation.dtype in static_dtypes:
            return QuantType.STATIC
        else:
            return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16:
        if hasattr(activation, "is_dynamic") and activation.is_dynamic:
            return QuantType.DYNAMIC
        elif activation.dtype == torch.float16:
            return QuantType.STATIC

    raise Exception(  # noqa: TRY002
        f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype}), "
        f"weight({weight.dtype})"
    )


def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value is less than or equal to the max value.
    """
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(
            "must run observer before calling calculate_qparams. "
            + "Returning default values."
        )
        return False

    if min_val.dim() == 0 or max_val.dim() == 0:
        if min_val == float("inf") and max_val == float("-inf"):
            warnings.warn(
                "must run observer before calling calculate_qparams. "
                + "Returning default values."
            )

            return False

        assert min_val <= max_val, f"min {min_val} should be less than or equal to max {max_val}"
    else:
        assert torch.all(
            min_val <= max_val
        ), f"min {min_val} should be less than or equal to max {max_val}"

    return True
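

# Illustrative, hypothetical usage sketch added for documentation only (never
# called, not part of the public API): it maps a few stock qconfigs onto
# QuantType values via ``get_quant_type``. The expected values in the comments
# assume the default observer settings shipped with PyTorch.
def _example_get_quant_type() -> None:
    from torch.ao.quantization.qconfig import (
        default_dynamic_qconfig,
        default_qconfig,
        float16_static_qconfig,
    )

    print(get_quant_type(default_qconfig))  # expected: QuantType.STATIC
    print(get_quant_type(default_dynamic_qconfig))  # expected: QuantType.DYNAMIC
    print(get_quant_type(float16_static_qconfig))  # expected: QuantType.STATIC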


def calculate_qmin_qmax(
    quant_min: int,
    quant_max: int,
    has_customized_qrange: bool,
    dtype: torch.dtype,
    reduce_range: bool,
) -> Tuple[int, int]:
    r"""Calculates actual qmin and qmax based on the quantization range,
    observer datatype and if range is reduced.
    """
    # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
    if has_customized_qrange:
        # This initialization is here to resolve TorchScript compilation issues and to allow
        # using refinement to decouple initial_quant_min and initial_quant_max from the quantization range.
        # The actual values of initial_quant_min and initial_quant_max will be reset below.
        if dtype in [torch.qint32, torch.int32]:
            initial_quant_min, initial_quant_max = 0, 2**32 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
        # The following assignment of quant_min and quant_max to local variables and the if check
        # refine them from Optional values to concrete integers, per TorchScript's requirements.
        custom_quant_min, custom_quant_max = quant_min, quant_max
        if custom_quant_min is not None and custom_quant_max is not None:
            initial_quant_min, initial_quant_max = (
                custom_quant_min,
                custom_quant_max,
            )

        qrange_len = initial_quant_max - initial_quant_min + 1
        if dtype in [torch.qint8, torch.int8]:
            assert (
                0 < qrange_len <= 256
            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
        elif dtype in [torch.qint32, torch.int32]:
            assert (
                0 < qrange_len <= 2**32
            ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
        if reduce_range:
            quant_min, quant_max = quant_min // 2, quant_max // 2
    else:
        # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
        if dtype in [torch.qint8, torch.int8]:
            if reduce_range:
                quant_min, quant_max = -64, 63
            else:
                quant_min, quant_max = -128, 127
        elif dtype in [torch.quint8, torch.uint8]:
            if reduce_range:
                quant_min, quant_max = 0, 127
            else:
                quant_min, quant_max = 0, 255
        elif dtype in [torch.qint32, torch.int32]:
            quant_min, quant_max = -1 * (2**31), (2**31) - 1
        else:
            quant_min, quant_max = 0, 15
    return quant_min, quant_max
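

# Illustrative, hypothetical usage sketch added for documentation only (never
# called, not part of the public API): the default 8-bit ranges produced by
# ``calculate_qmin_qmax`` with and without ``reduce_range``.
def _example_calculate_qmin_qmax() -> None:
    print(calculate_qmin_qmax(0, 255, False, torch.quint8, reduce_range=False))  # (0, 255)
    print(calculate_qmin_qmax(0, 255, False, torch.quint8, reduce_range=True))  # (0, 127)
    print(calculate_qmin_qmax(-128, 127, False, torch.qint8, reduce_range=True))  # (-64, 63)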


def _parent_name(target):
    """
    Turn 'foo.bar' into ('foo', 'bar'); if there is no parent, return ('', name)
    """
    r = target.rsplit(".", 1)
    if len(r) == 1:
        return "", r[0]
    else:
        return r[0], r[1]


def has_no_children_ignoring_parametrizations(module):
    """
    Checks if module._modules is empty or
    if module is a parametrization, checks that module._modules only has
    the 'parametrizations' module
    """
    if len(module._modules) == 0:
        return True
    elif is_parametrized(module):
        return len(module._modules) == 1 and "parametrizations" in module._modules
    else:
        return False


def _get_path_of_module(
    root: torch.nn.Module, submodule: torch.nn.Module
) -> Optional[str]:
    """Get the path (fully qualified name) of a submodule

    Example::

        >> class M(torch.nn.Module):
               def __init__(self) -> None:
                   self.linear = torch.nn.Linear(5, 5)
               def forward(self, x):
                   return self.linear(x)

        >> m = M()
        >> l = m.linear
        >> _get_path_of_module(m, l)
        "linear"
    """
    for n, p in root.named_modules():
        if submodule is p:
            return n
    return None


def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
    """Get local keyword arguments

    Example::

        >> def f(self, a, b=9):
               pass
        >> loc = {"a": 6, "c": 7}
        >> _get_signature_locals(f, loc)
        {"a": 6}
    """
    return {k: v for k, v in loc.items() if k in signature(f).parameters}


def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
    """Get all default keyword arguments from function signature

    Example::

        >> def f(self, a, b=9):
               pass
        >> _get_default_kwargs(f)
        {"b": 9}
    """
    kwargs = {}
    for name, param in signature(f).parameters.items():
        if param.default is not param.empty:
            kwargs[name] = param.default
        elif param.kind is param.VAR_POSITIONAL:
            kwargs[name] = ()
        elif param.kind is param.VAR_KEYWORD:
            kwargs[name] = {}
    return OrderedDict(kwargs)


def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]":
    """Given a function and local function arguments, normalize the keyword
    arguments by filling in default arguments from function signature

    Example::

        >> def f(self, key1=3, key2=3):
               pass
        >> loc = {"key2": 6}
        >> _normalize_kwargs(f, loc)
        {"key1": 3, "key2": 6}
    """
    default_kwargs = _get_default_kwargs(func)
    local_kwargs = _get_signature_locals(func, loc)
    normalized_kwargs = default_kwargs.copy()
    for attr, val in local_kwargs.items():
        if attr in normalized_kwargs:
            # override the default keyword arguments
            normalized_kwargs[attr] = val
    return normalized_kwargs


def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
    r"""Validates that the user-specified quantization range is properly initialized
    and within the given bound supported by the observer dtype.

    To accommodate lower-bit quantization with respect to the existing torch.qint8 and
    torch.quint8 datatypes, the user can choose to use a dynamic quantization range by passing
    in a tuple of initial qmin and qmax values. One use case is that these customized qmin and qmax
    values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
    fake quantization. These estimates are compared against parameters learned through backpropagation.
    Related literature on learning the scale and zero point via backpropagation:

    Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
    Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
    """
    # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
    # based on whether the quantization range is reduced and the datatype (signed/unsigned) used by the observer.
    assert (
        quant_min <= 0 <= quant_max
    ), "User-specified quantization range must include 0."
    assert (
        quant_min < quant_max
    ), "qmin must be strictly less than qmax for user-specified quantization range."
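

# Illustrative, hypothetical usage sketch added for documentation only (never
# called, not part of the public API): a 4-bit signed range passes validation,
# while a range that excludes zero fails.
def _example_validate_qmin_qmax() -> None:
    validate_qmin_qmax(-8, 7)  # ok: includes 0 and qmin < qmax
    try:
        validate_qmin_qmax(1, 15)  # does not include 0
    except AssertionError as e:
        print(e)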


# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable, however, and
# qscheme, as far as I can tell, is not allowed to be passed as a parameter in torchscript functions. This makes
# refactoring observer to use this utility a massive pain and very gross. For now I'm opting just to duplicate, as this
# code seems unlikely to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor.
# TODO(jakeszwe, jerryzh168)
def determine_qparams(
    min_val: torch.Tensor,
    max_val: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    eps: torch.Tensor,
    has_customized_qrange: bool,
    qscheme: torch.qscheme = torch.per_tensor_affine,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates the quantization parameters, given min and max
    value tensors. Works for both per tensor and per channel cases

    Args:
        min_val: Minimum values per channel
        max_val: Maximum values per channel

    Returns:
        scales: Scales tensor of shape (#channels,)
        zero_points: Zero points tensor of shape (#channels,)
    """
    if not check_min_max_valid(min_val, max_val):
        return torch.tensor([1.0], device=min_val.device.type), torch.tensor(
            [0], device=min_val.device.type
        )

    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    device = min_val_neg.device
    scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        scale = torch.max(scale, eps)
        if dtype in [torch.uint8, torch.quint8]:
            if has_customized_qrange:
                # When a customized quantization range is used, the down-rounded midpoint of the range is chosen.
                zero_point = zero_point.new_full(
                    zero_point.size(), (quant_min + quant_max) // 2
                )
            else:
                zero_point = zero_point.new_full(zero_point.size(), 128)
    elif qscheme == torch.per_channel_affine_float_qparams:
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
        # We use the quantize function
        # Xq = Round(Xf * inv_scale + zero_point),
        # setting zero_point to (-1 * min * inv_scale) we get
        # Xq = Round((Xf - min) * inv_scale)
        zero_point = -1 * min_val / scale
    else:
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.max(scale, eps)
        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)

    # For scalar values, cast them to Tensors of size 1 to keep the shape
    # consistent with default values in FakeQuantize.
    if len(scale.shape) == 0:
        # TODO: switch to scale.item() after adding JIT support
        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
    if len(zero_point.shape) == 0:
        # TODO: switch to zero_point.item() after adding JIT support
        zero_point = torch.tensor(
            [int(zero_point)], dtype=zero_point.dtype, device=device
        )
        if qscheme == torch.per_channel_affine_float_qparams:
            zero_point = torch.tensor(
                [float(zero_point)], dtype=zero_point.dtype, device=device
            )

    return scale.to(torch.double), zero_point.to(torch.int64)


def _get_num_pos_args(f: Callable) -> int:
    """Get number of positional args for a function

    Example::

        >> def f(self, key1=3, key2=3):
               pass
        >> _get_num_pos_args(f)
        3
    """
    return len(getfullargspec(f).args)
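

# Illustrative, hypothetical usage sketch added for documentation only (never
# called, not part of the public API): per-tensor affine qparams for an observed
# range of [-1.0, 2.0] mapped to uint8.
def _example_determine_qparams() -> None:
    scale, zero_point = determine_qparams(
        min_val=torch.tensor(-1.0),
        max_val=torch.tensor(2.0),
        quant_min=0,
        quant_max=255,
        dtype=torch.uint8,
        eps=torch.tensor(torch.finfo(torch.float32).eps),
        has_customized_qrange=False,
    )
    print(scale, zero_point)  # scale is roughly 3/255, zero_point is roughly 85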


def get_fqn_to_example_inputs(
    model: torch.nn.Module, example_inputs: Tuple[Any, ...]
) -> Dict[str, Tuple[Any, ...]]:
    """Given a model and its example inputs, return a dictionary from
    fully qualified name of submodules to example_inputs for that submodule,
    e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
    "sub.linear1": (tensor4,), ...}

    Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
    example inputs.

    Also works for keyword arguments with default values: we flatten keyword
    arguments into positional arguments and fill in the missing keyword args with default
    values. E.g. if we have a forward function:
    def forward(self, x, key1=3, key2=3):
        ...

    and we call it with self.submodule(x, key2=6)
    we'll get example_inputs: (x, 3, 6)

    The user can also override `key1` with a positional argument:
    for self.submodule(x, 5, key2=6)
    we'll get: (x, 5, 6)

    Variable positional arguments and variable keyword arguments in the forward
    function are not supported currently, so please make sure no submodule uses
    them.
    """
    root = model
    fqn_to_example_inputs = {}

    def _patched_module_call(self, *args, **kwargs):
        submodule_example_inputs = list(args).copy()
        normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
        # minus 1 to skip counting `self`
        num_args = _get_num_pos_args(self.forward) - 1
        num_to_pop = num_args - len(submodule_example_inputs)
        while num_to_pop and normalized_kwargs:
            normalized_kwargs.popitem(last=False)
            num_to_pop -= 1
        submodule_example_inputs.extend(normalized_kwargs.values())
        submodule_example_inputs_tuple = tuple(submodule_example_inputs)
        fqn = _get_path_of_module(root, self)
        if fqn is not None:
            fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
        return orig_module_call(self, *args, **kwargs)

    orig_module_call = torch.nn.Module.__call__
    torch.nn.Module.__call__ = _patched_module_call  # type: ignore[method-assign]
    try:
        model(*example_inputs)
    finally:
        # restore the module call even if there is an exception
        torch.nn.Module.__call__ = orig_module_call  # type: ignore[method-assign]
    return fqn_to_example_inputs


def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    devices = {p.device for p in module.parameters()} | {
        p.device for p in module.buffers()
    }
    # As a temp workaround for the AIMP HHC publish we added the CPU check. Remove it later. T163614564
    if {torch.device("cpu"), torch.device("meta")} == devices:
        warnings.warn(
            "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We select 'cpu'."
        )
        devices = {torch.device("cpu")}
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        f"but got devices {devices}"
    )
    device = next(iter(devices)) if len(devices) > 0 else None
    return device


__all__ = [
    "NodePattern",
    "Pattern",
    "MatchAllNode",
    "check_node",
    "get_combined_dict",
    "is_per_tensor",
    "is_per_channel",
    "getattr_from_fqn",
    "get_qparam_dict",
    "get_swapped_custom_module_class",
    "activation_dtype",
    "weight_dtype",
    "activation_is_statically_quantized",
    "activation_is_dynamically_quantized",
    "activation_is_int8_quantized",
    "activation_is_int32_quantized",
    "weight_is_quantized",
    "weight_is_statically_quantized",
    "op_is_int8_dynamically_quantized",
    "get_qconfig_dtypes",
    "get_quant_type",
    "check_min_max_valid",
    "calculate_qmin_qmax",
    "has_no_children_ignoring_parametrizations",
    "get_fqn_to_example_inputs",
    "to_underlying_dtype",
    "determine_qparams",
    "validate_qmin_qmax",
]
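

# Illustrative, hypothetical usage sketch added for documentation only (never
# called, not part of the public API): record example inputs for every
# submodule of a small nested model with ``get_fqn_to_example_inputs``.
def _example_get_fqn_to_example_inputs() -> None:
    class _Sub(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x)

    class _M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sub = _Sub()

        def forward(self, x):
            return self.sub(x)

    fqn_to_inputs = get_fqn_to_example_inputs(_M(), (torch.randn(2, 5),))
    print(sorted(fqn_to_inputs.keys()))  # ['', 'sub', 'sub.linear'] (root fqn is '')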