1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3""" 4This module implements observers which are used to collect statistics about 5the values observed during calibration (PTQ) or training (QAT). 6""" 7 8import re 9import warnings 10from abc import ABCMeta, abstractmethod 11from collections import OrderedDict 12from functools import partial 13from typing import Any, Dict, List, Optional, Tuple 14 15import torch 16import torch.nn as nn 17from torch.ao.quantization.utils import ( 18 calculate_qmin_qmax, 19 check_min_max_valid, 20 is_per_channel, 21 is_per_tensor, 22 validate_qmin_qmax, 23) 24 25 26__all__ = [ 27 "default_affine_fixed_qparams_observer", 28 "default_debug_observer", 29 "default_dynamic_quant_observer", 30 "default_fixed_qparams_range_0to1_observer", 31 "default_fixed_qparams_range_neg1to1_observer", 32 "default_float_qparams_observer", 33 "default_float_qparams_observer_4bit", 34 "default_histogram_observer", 35 "default_observer", 36 "default_per_channel_weight_observer", 37 "default_placeholder_observer", 38 "default_reuse_input_observer", 39 "default_symmetric_fixed_qparams_observer", 40 "default_weight_observer", 41 "get_observer_state_dict", 42 "load_observer_state_dict", 43 "per_channel_weight_observer_range_neg_127_to_127", 44 "weight_observer_range_neg_127_to_127", 45 "FixedQParamsObserver", 46 "HistogramObserver", 47 "MinMaxObserver", 48 "MovingAverageMinMaxObserver", 49 "MovingAveragePerChannelMinMaxObserver", 50 "NoopObserver", 51 "ObserverBase", 52 "PerChannelMinMaxObserver", 53 "PlaceholderObserver", 54 "RecordingObserver", 55 "ReuseInputObserver", 56 "UniformQuantizationObserverBase", 57] 58 59 60class _PartialWrapper: 61 def __init__(self, p): 62 self.p = p 63 self.callable_args = {} 64 65 def __call__(self, *args, **keywords): 66 # call each arg in callable_args and add them partial, then run with keywords 67 # skip if arg_name in keywords so its possible to overwrite 68 for arg_name in self.callable_args: 69 if arg_name not in keywords: 70 keywords = {**keywords, arg_name: self.callable_args[arg_name]()} 71 return self.p(*args, **keywords) 72 73 def __repr__(self): 74 return self.p.__repr__() + self.callable_args.__repr__() 75 76 def with_args(self, **kwargs): 77 return _with_args(self, **kwargs) 78 79 def with_callable_args(self, **kwargs): 80 result = _PartialWrapper(p=self.p) 81 result.callable_args = {**self.callable_args, **kwargs} 82 return result 83 84 85def _with_args(cls_or_self, **kwargs): 86 r"""Wrapper that allows creation of class factories. 87 88 This can be useful when there is a need to create classes with the same 89 constructor arguments, but different instances. Can be used in conjunction with 90 _callable_args 91 92 Example:: 93 94 >>> # xdoctest: +SKIP("Undefined vars") 95 >>> Foo.with_args = classmethod(_with_args) 96 >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) 97 >>> foo_instance1 = foo_builder() 98 >>> foo_instance2 = foo_builder() 99 >>> id(foo_instance1) == id(foo_instance2) 100 False 101 """ 102 r = _PartialWrapper(partial(cls_or_self, **kwargs)) 103 return r 104 105 106def _with_callable_args(cls_or_self, **kwargs): 107 r"""Wrapper that allows creation of class factories args that need to be 108 called at construction time. 109 110 This can be useful when there is a need to create classes with the same 111 constructor arguments, but different instances and those arguments should only 112 be calculated at construction time. 
Can be used in conjunction with _with_args 113 114 Example:: 115 116 >>> # xdoctest: +SKIP("Undefined vars") 117 >>> Foo.with_callable_args = classmethod(_with_callable_args) 118 >>> Foo.with_args = classmethod(_with_args) 119 >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") 120 >>> foo_instance1 = foo_builder() 121 >>> # wait 50 122 >>> foo_instance2 = foo_builder() 123 >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) 124 False 125 """ 126 r = _PartialWrapper(partial(cls_or_self)) 127 return r.with_callable_args(**kwargs) 128 129 130ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: 131 132 133class ObserverBase(ABC, nn.Module): 134 r"""Base observer Module. 135 Any observer implementation should derive from this class. 136 137 Concrete observers should follow the same API. In forward, they will update 138 the statistics of the observed Tensor. And they should provide a 139 `calculate_qparams` function that computes the quantization parameters given 140 the collected statistics. 141 142 Args: 143 dtype: dtype argument to the `quantize` node needed to implement the 144 reference model spec. 145 is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization 146 or static quantization 147 """ 148 149 def __init__(self, dtype, is_dynamic=False): 150 super().__init__() 151 self.dtype = dtype 152 self.is_dynamic = is_dynamic 153 154 @abstractmethod 155 def forward(self, x): 156 pass 157 158 @abstractmethod 159 def calculate_qparams(self, **kwargs): 160 pass 161 162 with_args = classmethod(_with_args) 163 with_callable_args = classmethod(_with_callable_args) 164 165 166class UniformQuantizationObserverBase(ObserverBase): 167 r"""Common base for all observers using uniform quantization to calculate 168 scale and zero_point. 169 170 Args: 171 dtype: dtype argument to the `quantize` node needed to implement the 172 reference model spec. 173 qscheme: Quantization scheme to be used. 174 reduce_range: Reduces the range of the quantized data type by 1 bit. 175 This is sometimes required to avoid instruction overflow. 176 quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. 177 quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. 178 eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 179 180 .. warning:: 181 182 :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. 183 or `torch.int8` or `torch.uint8` 184 185 .. warning:: 186 187 :attr:`qscheme` can only take one of the following options: 188 189 - ``torch.per_tensor_affine`` 190 - ``torch.per_tensor_symmetric`` 191 - ``torch.per_channel_affine`` 192 - ``torch.per_channel_symmetric`` 193 """ 194 195 # Note: the version is shared by all observer types 196 # 197 # Version 1/None 198 # self 199 # 200 # Version 2 (base class only, does not include child class buffers) 201 # self 202 # |--- eps : Tensor 203 # 204 # Version 3 205 # for HistogramObserver only, changed the shape of uninitialized 206 # min_val and max_val buffers from torch.Size([0]) to torch.Size([]) 207 # for PerChannelObservers, changed the name of the buffers from min_vals 208 # to min_val and from max_vals to max_val. 
209 _version = 3 210 211 eps: torch.Tensor 212 213 def __init__( 214 self, 215 dtype=torch.quint8, 216 qscheme=torch.per_tensor_affine, 217 reduce_range=False, 218 quant_min=None, 219 quant_max=None, 220 factory_kwargs=None, 221 eps=torch.finfo(torch.float32).eps, 222 is_dynamic=False, 223 **kwargs, 224 ) -> None: 225 factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) 226 super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) 227 self.qscheme = qscheme 228 if reduce_range: 229 warnings.warn( 230 "Please use quant_min and quant_max to specify the range for observers. \ 231 reduce_range will be deprecated in a future release of PyTorch." 232 ) 233 self.reduce_range = reduce_range 234 self.register_buffer("eps", torch.tensor([eps], **factory_kwargs)) 235 assert self.qscheme in ( 236 torch.per_tensor_affine, 237 torch.per_tensor_symmetric, 238 torch.per_channel_affine, 239 torch.per_channel_symmetric, 240 torch.per_channel_affine_float_qparams, 241 ), "Default Observer only works for per_tensor_affine, \ 242 per_tensor_symmetric, per_channel_affine, \ 243 per_channel_symmetric and per_channel_float_qparams quantization scheme" 244 245 _ALLOWED_DTYPES = ( 246 torch.qint8, 247 torch.quint8, 248 torch.quint4x2, 249 torch.qint32, 250 torch.int8, 251 torch.uint8, 252 torch.int16, 253 torch.int32, 254 torch.float8_e5m2, 255 torch.float8_e4m3fn, 256 ) 257 258 assert ( 259 self.dtype in _ALLOWED_DTYPES 260 ), f"Default Observer only works for {_ALLOWED_DTYPES} data type" 261 self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) 262 if self.has_customized_qrange: 263 validate_qmin_qmax(quant_min, quant_max) 264 self.quant_min, self.quant_max = calculate_qmin_qmax( 265 quant_min, 266 quant_max, 267 self.has_customized_qrange, 268 self.dtype, 269 self.reduce_range, 270 ) 271 272 def _load_from_state_dict( 273 self, 274 state_dict, 275 prefix, 276 local_metadata, 277 strict, 278 missing_keys, 279 unexpected_keys, 280 error_msgs, 281 ): 282 version = local_metadata.get("version", None) 283 284 if version is None or version == 1: 285 # eps was moved to a buffer in version 2 286 eps = torch.tensor([torch.finfo(torch.float32).eps]) 287 state_dict[prefix + "eps"] = eps 288 289 super()._load_from_state_dict( 290 state_dict, 291 prefix, 292 local_metadata, 293 strict, 294 missing_keys, 295 unexpected_keys, 296 error_msgs, 297 ) 298 299 @torch.jit.export 300 def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: 301 r"""Validates that the user-specified quantization range is properly initialized 302 and within the given bound supported by the observer dtype. 303 304 To accommodate lower-bit quantization with respect to the existing torch.qint8 and 305 torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing 306 in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax 307 values are used to calculate static estimates of the scale and zero point for aggressive lower-bit 308 fake quantization. These estimates are compared against parameters learned through backpropagation. 
The related literature for learning scale and zero point via backpropagation is as follows:

        Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
        Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
        """
        # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
        # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
        assert (
            quant_min <= 0 <= quant_max
        ), "User-specified quantization range must include 0."
        assert (
            quant_min < quant_max
        ), "qmin must be strictly less than qmax for user-specified quantization range."

    @torch.jit.export
    def _calculate_qparams(
        self, min_val: torch.Tensor, max_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Calculates the quantization parameters, given min and max
        value tensors. Works for both per tensor and per channel cases.

        Args:
            min_val: Minimum values per channel
            max_val: Maximum values per channel

        Returns:
            scales: Scales tensor of shape (#channels,)
            zero_points: Zero points tensor of shape (#channels,)
        """
        # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable, however,
        # and as far as I can tell qscheme is not allowed to be passed as a parameter in torchscript functions.
        # This makes refactoring the observer to use that utility a massive pain. For now I'm opting to just
        # duplicate, as this code seems unlikely to change (last update over 1 year ago); once torchscript is
        # fully deprecated we can refactor.
        # TODO(jakeszwe, jerryzh168)
        if not check_min_max_valid(min_val, max_val):
            return torch.tensor([1.0], device=min_val.device.type), torch.tensor(
                [0], device=min_val.device.type
            )

        quant_min, quant_max = self.quant_min, self.quant_max
        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

        device = min_val_neg.device
        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

        if (
            self.qscheme == torch.per_tensor_symmetric
            or self.qscheme == torch.per_channel_symmetric
        ):
            max_val_pos = torch.max(-min_val_neg, max_val_pos)
            scale = max_val_pos / (float(quant_max - quant_min) / 2)
            scale = torch.max(scale, self.eps)
            if self.dtype in [torch.quint8, torch.uint8]:
                if self.has_customized_qrange:
                    # When a customized quantization range is used, the down-rounded midpoint of the range is chosen.
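                    # For example, a customized quint8 range of [quant_min, quant_max] = [0, 99]
                    # yields zero_point = (0 + 99) // 2 = 49, i.e. the midpoint rounded down.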
366 zero_point = zero_point.new_full( 367 zero_point.size(), (quant_min + quant_max) // 2 368 ) 369 else: 370 zero_point = zero_point.new_full(zero_point.size(), 128) 371 elif self.qscheme == torch.per_channel_affine_float_qparams: 372 scale = (max_val - min_val) / float(quant_max - quant_min) 373 scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) 374 # We use the quantize function 375 # xq = Round(Xf * inv_scale + zero_point), 376 # setting zero_point to (-1 * min *inv_scale) we get 377 # Xq = Round((Xf - min) * inv_scale) 378 zero_point = -1 * min_val / scale 379 else: 380 scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) 381 scale = torch.max(scale, self.eps) 382 zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) 383 zero_point = torch.clamp(zero_point, quant_min, quant_max) 384 385 # For scalar values, cast them to Tensors of size 1 to keep the shape 386 # consistent with default values in FakeQuantize. 387 if len(scale.shape) == 0: 388 # TODO: switch to scale.item() after adding JIT support 389 scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) 390 if len(zero_point.shape) == 0: 391 # TODO: switch to zero_point.item() after adding JIT support 392 zero_point = torch.tensor( 393 [int(zero_point)], dtype=zero_point.dtype, device=device 394 ) 395 if self.qscheme == torch.per_channel_affine_float_qparams: 396 zero_point = torch.tensor( 397 [float(zero_point)], dtype=zero_point.dtype, device=device 398 ) 399 400 return scale, zero_point 401 402 @torch.jit.export 403 def reset_min_max_vals(self): 404 raise NotImplementedError("Cannot reset min/max values in the given observer.") 405 406 407# Originally, this class was called `_ObserverBase`. Keeping the old name around 408# for backwards compatibility. 409# TODO(after v1.13): delete this 410_ObserverBase = UniformQuantizationObserverBase 411 412 413class MinMaxObserver(UniformQuantizationObserverBase): 414 r"""Observer module for computing the quantization parameters based on the 415 running min and max values. 416 417 This observer uses the tensor min/max statistics to compute the quantization 418 parameters. The module records the running minimum and maximum of incoming 419 tensors, and uses this statistic to compute the quantization parameters. 420 421 Args: 422 dtype: dtype argument to the `quantize` node needed to implement the 423 reference model spec. 424 qscheme: Quantization scheme to be used 425 reduce_range: Reduces the range of the quantized data type by 1 bit 426 quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. 427 quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. 428 eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 429 430 Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, 431 scale :math:`s` and zero point :math:`z` are computed as: 432 433 The running minimum/maximum :math:`x_\text{min/max}` is computed as: 434 435 .. math:: 436 437 \begin{array}{ll} 438 x_\text{min} &= \begin{cases} 439 \min(X) & \text{if~}x_\text{min} = \text{None} \\ 440 \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} 441 \end{cases}\\ 442 x_\text{max} &= \begin{cases} 443 \max(X) & \text{if~}x_\text{max} = \text{None} \\ 444 \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} 445 \end{cases}\\ 446 \end{array} 447 448 where :math:`X` is the observed tensor. 449 450 The scale :math:`s` and zero point :math:`z` are then computed as: 451 452 .. 
math:: 453 454 \begin{aligned} 455 \text{if Symmetric:}&\\ 456 &s = 2 \max(|x_\text{min}|, x_\text{max}) / 457 \left( Q_\text{max} - Q_\text{min} \right) \\ 458 &z = \begin{cases} 459 0 & \text{if dtype is qint8} \\ 460 128 & \text{otherwise} 461 \end{cases}\\ 462 \text{Otherwise:}&\\ 463 &s = \left( x_\text{max} - x_\text{min} \right ) / 464 \left( Q_\text{max} - Q_\text{min} \right ) \\ 465 &z = Q_\text{min} - \text{round}(x_\text{min} / s) 466 \end{aligned} 467 468 where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and 469 maximum of the quantized data type. 470 471 .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. 472 473 .. note:: If the running minimum equals to the running maximum, the scale 474 and zero_point are set to 1.0 and 0. 475 """ 476 min_val: torch.Tensor 477 max_val: torch.Tensor 478 479 def __init__( 480 self, 481 dtype=torch.quint8, 482 qscheme=torch.per_tensor_affine, 483 reduce_range=False, 484 quant_min=None, 485 quant_max=None, 486 factory_kwargs=None, 487 eps=torch.finfo(torch.float32).eps, 488 is_dynamic=False, 489 **kwargs, 490 ) -> None: 491 if not is_per_tensor(qscheme): 492 raise NotImplementedError( 493 "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \ 494 and torch.per_tensor_affine." 495 ) 496 # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but 497 # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it 498 # supports dynamic quantization, we may need to better error checking here 499 500 # For x86 quantized kernels, we need to ensure that the vpmaddubsw 501 # instruction does not overflow. We allow for a reduce_range argument to 502 # observers that reduces the quantized range to (0,127) or (-64, 63). 503 # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp 504 # This is not an optimal choice for non x86 backends as it loses a bit 505 # of precision for activations. 
506 super().__init__( 507 dtype=dtype, 508 qscheme=qscheme, 509 reduce_range=reduce_range, 510 quant_min=quant_min, 511 quant_max=quant_max, 512 factory_kwargs=factory_kwargs, 513 eps=eps, 514 is_dynamic=is_dynamic, 515 **kwargs, 516 ) 517 factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) 518 self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) 519 self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) 520 if ( 521 self.qscheme == torch.per_tensor_symmetric 522 and self.reduce_range 523 and self.dtype == torch.quint8 524 ): 525 raise NotImplementedError( 526 "Cannot reduce range for symmetric \ 527 quantization for quint8" 528 ) 529 530 def forward(self, x_orig): 531 r"""Records the running minimum and maximum of ``x``.""" 532 if x_orig.numel() == 0: 533 return x_orig 534 x = x_orig.detach() # avoid keeping autograd tape 535 x = x.to(self.min_val.dtype) 536 min_val_cur, max_val_cur = torch.aminmax(x) 537 min_val = torch.min(min_val_cur, self.min_val) 538 max_val = torch.max(max_val_cur, self.max_val) 539 self.min_val.copy_(min_val) 540 self.max_val.copy_(max_val) 541 return x_orig 542 543 @torch.jit.export 544 def calculate_qparams(self): 545 r"""Calculates the quantization parameters.""" 546 return self._calculate_qparams(self.min_val, self.max_val) 547 548 @torch.jit.export 549 def extra_repr(self): 550 return f"min_val={self.min_val}, max_val={self.max_val}" 551 552 @torch.jit.export 553 def reset_min_max_vals(self): 554 """Resets the min/max values.""" 555 self.min_val.copy_(torch.tensor(float("inf"))) 556 self.max_val.copy_(torch.tensor(float("-inf"))) 557 558 559class MovingAverageMinMaxObserver(MinMaxObserver): 560 r"""Observer module for computing the quantization parameters based on the 561 moving average of the min and max values. 562 563 This observer computes the quantization parameters based on the moving 564 averages of minimums and maximums of the incoming tensors. The module 565 records the average minimum and maximum of incoming tensors, and uses this 566 statistic to compute the quantization parameters. 567 568 Args: 569 averaging_constant: Averaging constant for min/max. 570 dtype: dtype argument to the `quantize` node needed to implement the 571 reference model spec. 572 qscheme: Quantization scheme to be used 573 reduce_range: Reduces the range of the quantized data type by 1 bit 574 quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. 575 quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. 576 eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 577 578 The moving average min/max is computed as follows 579 580 .. math:: 581 582 \begin{array}{ll} 583 x_\text{min} = \begin{cases} 584 \min(X) & \text{if~}x_\text{min} = \text{None} \\ 585 (1 - c) x_\text{min} + c \min(X) & \text{otherwise} 586 \end{cases}\\ 587 x_\text{max} = \begin{cases} 588 \max(X) & \text{if~}x_\text{max} = \text{None} \\ 589 (1 - c) x_\text{max} + c \max(X) & \text{otherwise} 590 \end{cases}\\ 591 \end{array} 592 593 where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is 594 is the incoming tensor, and :math:`c` is the ``averaging_constant``. 595 596 The scale and zero point are then computed as in 597 :class:`~torch.ao.quantization.observer.MinMaxObserver`. 598 599 .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. 600 601 .. 
note:: If the running minimum equals the running maximum, the scale
        and zero_point are set to 1.0 and 0.
    """

    def __init__(
        self,
        averaging_constant=0.01,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        eps=torch.finfo(torch.float32).eps,
        is_dynamic=False,
        **kwargs,
    ) -> None:
        if not is_per_tensor(qscheme):
            raise NotImplementedError(
                f"MovingAverageMinMaxObserver's qscheme only supports \
                torch.per_tensor_symmetric and torch.per_tensor_affine, \
                but got: {qscheme}"
            )
        self.averaging_constant = averaging_constant
        if is_dynamic and self.averaging_constant != 1:
            raise NotImplementedError(
                "MovingAverageMinMaxObserver doesn't support dynamic quantization for "
                f"averaging constant of {self.averaging_constant}"
            )
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )

    def forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()  # avoid keeping autograd tape
        x = x.to(self.min_val.dtype)
        min_val = self.min_val
        max_val = self.max_val
        if min_val == float("inf") and max_val == float("-inf"):
            min_val, max_val = torch.aminmax(x)
        else:
            min_val_cur, max_val_cur = torch.aminmax(x)
            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)
        return x_orig


class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
    r"""Observer module for computing the quantization parameters based on the
    running per channel min and max values.

    This observer uses the tensor min/max statistics to compute the per channel
    quantization parameters. The module records the running minimum and maximum
    of incoming tensors, and uses this statistic to compute the quantization
    parameters.

    Args:
        ch_axis: Channel axis
        dtype: dtype argument to the `quantize` node needed to implement the
               reference model spec.
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit
        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
        eps: Epsilon value for float32; defaults to `torch.finfo(torch.float32).eps`.

    The quantization parameters are computed the same way as in
    :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
    that the running min/max values are stored per channel.
    Scales and zero points are thus computed per channel as well.

    .. note:: If the running minimum equals the running maximum, the scales
        and zero_points are set to 1.0 and 0.
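
    A minimal usage sketch (illustrative only; the resulting qparams depend on
    the observed tensors):

    Example::

        >>> # xdoctest: +SKIP("illustrative example")
        >>> obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
        ...                                qscheme=torch.per_channel_symmetric)
        >>> _ = obs(torch.randn(4, 8))  # e.g. a (out_channels, in_features) weight
        >>> scales, zero_points = obs.calculate_qparams()
        >>> scales.shape  # one scale/zero_point pair per channel along ch_axis
        torch.Size([4])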
684 """ 685 min_val: torch.Tensor 686 max_val: torch.Tensor 687 688 def __init__( 689 self, 690 ch_axis=0, 691 dtype=torch.quint8, 692 qscheme=torch.per_channel_affine, 693 reduce_range=False, 694 quant_min=None, 695 quant_max=None, 696 factory_kwargs=None, 697 eps=torch.finfo(torch.float32).eps, 698 is_dynamic=False, 699 **kwargs, 700 ) -> None: 701 if not is_per_channel(qscheme): 702 raise NotImplementedError( 703 "PerChannelMinMaxObserver's qscheme only support \ 704 torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." 705 ) 706 if is_dynamic: 707 raise NotImplementedError( 708 "PerChannelMinMaxObserver doesn't support dynamic quantization" 709 ) 710 super().__init__( 711 dtype=dtype, 712 qscheme=qscheme, 713 reduce_range=reduce_range, 714 quant_min=quant_min, 715 quant_max=quant_max, 716 factory_kwargs=factory_kwargs, 717 eps=eps, 718 is_dynamic=is_dynamic, 719 **kwargs, 720 ) 721 factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) 722 self.ch_axis = ch_axis 723 self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) 724 self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) 725 if ( 726 self.qscheme == torch.per_channel_symmetric 727 and self.reduce_range 728 and self.dtype == torch.quint8 729 ): 730 raise NotImplementedError( 731 "Cannot reduce range for symmetric quantization for quint8" 732 ) 733 734 def forward(self, x_orig): 735 return self._forward(x_orig) 736 737 def _forward(self, x_orig): 738 if x_orig.numel() == 0: 739 return x_orig 740 x = x_orig.detach() # avoid keeping autograd tape 741 min_val = self.min_val 742 max_val = self.max_val 743 x_dim = x.size() 744 745 new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 746 new_axis_list[self.ch_axis] = 0 747 new_axis_list[0] = self.ch_axis 748 y = x.permute(new_axis_list) 749 # Need to match dtype of min/max because the updates to buffers 750 # are done in place and types need to match for comparisons 751 y = y.to(self.min_val.dtype) 752 y = torch.flatten(y, start_dim=1) 753 if min_val.numel() == 0 or max_val.numel() == 0: 754 min_val, max_val = torch.aminmax(y, dim=1) 755 else: 756 min_val_cur, max_val_cur = torch.aminmax(y, dim=1) 757 min_val = torch.min(min_val_cur, min_val) 758 max_val = torch.max(max_val_cur, max_val) 759 self.min_val.resize_(min_val.shape) 760 self.max_val.resize_(max_val.shape) 761 self.min_val.copy_(min_val) 762 self.max_val.copy_(max_val) 763 return x_orig 764 765 @torch.jit.export 766 def calculate_qparams(self): 767 return self._calculate_qparams(self.min_val, self.max_val) 768 769 def extra_repr(self): 770 return f"min_val={self.min_val}, max_val={self.max_val}" 771 772 def _load_from_state_dict( 773 self, 774 state_dict: Dict[str, Any], 775 prefix: str, 776 local_metadata: Dict[str, torch.Tensor], 777 strict: bool, 778 missing_keys: List[str], 779 unexpected_keys: List[str], 780 error_msgs: List[str], 781 ): 782 version = local_metadata.get("version", None) 783 if version is not None and version < 3: 784 local_state = ["min_vals", "max_vals"] 785 expected_min_name = "min_vals" 786 expected_max_name = "max_vals" 787 else: 788 local_state = ["min_val", "max_val"] 789 expected_min_name = "min_val" 790 expected_max_name = "max_val" 791 for name in local_state: 792 key = prefix + name 793 if key in state_dict: 794 val = state_dict[key] 795 # Custom handling to allow loading min_val or max_val 796 # of size N into uninitialized buffers of size 0. 
The 797 # buffers are resized here, and the values are copied in 798 # the default state_dict loading code of the parent. 799 if name == expected_min_name: 800 self.min_val.resize_(val.shape) 801 elif name == expected_max_name: 802 self.max_val.resize_(val.shape) 803 else: 804 warnings.warn( 805 f"Observer load_from_state_dict got unexpected name {name}" 806 ) 807 # For torchscript module we need to update the attributes here since we do not 808 # call the `_load_from_state_dict` function defined module.py 809 if torch.jit.is_scripting(): 810 if name == expected_min_name: 811 self.min_val.copy_(val) 812 elif name == expected_max_name: 813 self.max_val.copy_(val) 814 else: 815 warnings.warn( 816 f"Observer load_from_state_dict got unexpected name {name}" 817 ) 818 elif strict: 819 missing_keys.append(key) 820 821 if not torch.jit.is_scripting(): 822 super()._load_from_state_dict( 823 state_dict, 824 prefix, 825 local_metadata, 826 False, 827 missing_keys, 828 unexpected_keys, 829 error_msgs, 830 ) 831 832 def _load_from_state_dict_script( 833 self, 834 state_dict: Dict[str, Any], 835 prefix: str, 836 local_metadata: Dict[str, torch.Tensor], 837 strict: bool, 838 missing_keys: List[str], 839 unexpected_keys: List[str], 840 error_msgs: List[str], 841 ): 842 self._load_from_state_dict( 843 state_dict, 844 prefix, 845 local_metadata, 846 strict, 847 missing_keys, 848 unexpected_keys, 849 error_msgs, 850 ) 851 852 @torch.jit.export 853 def reset_min_max_vals(self): 854 """Resets the min/max values.""" 855 # This used to be torch.ones but that does not work because 856 # JIT compiler can optimize it via common subexpression elimination 857 # in which case both min_val and max_val point to the same tensor. 858 self.min_val = torch.rand( 859 0, 860 ) 861 self.max_val = torch.rand( 862 0, 863 ) 864 865 866class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): 867 r"""Observer module for computing the quantization parameters based on the 868 running per channel min and max values. 869 870 This observer uses the tensor min/max statistics to compute the per channel 871 quantization parameters. The module records the running minimum and maximum 872 of incoming tensors, and uses this statistic to compute the quantization 873 parameters. 874 875 Args: 876 averaging_constant: Averaging constant for min/max. 877 ch_axis: Channel axis 878 dtype: Quantized data type 879 qscheme: Quantization scheme to be used 880 reduce_range: Reduces the range of the quantized data type by 1 bit 881 quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. 882 quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. 883 eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 884 885 The quantization parameters are computed the same way as in 886 :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the 887 difference that the running min/max values are stored per channel. 888 Scales and zero points are thus computed per channel as well. 889 890 .. note:: If the running minimum equals to the running maximum, the scales 891 and zero_points are set to 1.0 and 0. 
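
    A minimal usage sketch (``data_batches`` below is a stand-in for any iterable
    of calibration tensors):

    Example::

        >>> # xdoctest: +SKIP("illustrative example")
        >>> obs = MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0)
        >>> for data in data_batches:
        ...     _ = obs(data)
        >>> scales, zero_points = obs.calculate_qparams()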
892 """ 893 894 def __init__( 895 self, 896 averaging_constant=0.01, 897 ch_axis=0, 898 dtype=torch.quint8, 899 qscheme=torch.per_channel_affine, 900 reduce_range=False, 901 quant_min=None, 902 quant_max=None, 903 eps=torch.finfo(torch.float32).eps, 904 is_dynamic=False, 905 **kwargs, 906 ) -> None: 907 if not is_per_channel(qscheme): 908 raise NotImplementedError( 909 "MovingAveragePerChannelMinMaxObserver's qscheme only support \ 910 torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." 911 ) 912 if is_dynamic: 913 raise NotImplementedError( 914 "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization" 915 ) 916 super().__init__( 917 ch_axis=ch_axis, 918 dtype=dtype, 919 qscheme=qscheme, 920 reduce_range=reduce_range, 921 quant_min=quant_min, 922 quant_max=quant_max, 923 eps=eps, 924 is_dynamic=is_dynamic, 925 **kwargs, 926 ) 927 self.averaging_constant = averaging_constant 928 929 def forward(self, x_orig): 930 if x_orig.numel() == 0: 931 return x_orig 932 x = x_orig.detach() # avoid keeping autograd tape 933 x = x.to(self.min_val.dtype) 934 min_val = self.min_val 935 max_val = self.max_val 936 x_dim = x.size() 937 938 new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 939 new_axis_list[self.ch_axis] = 0 940 new_axis_list[0] = self.ch_axis 941 y = x.permute(new_axis_list) 942 y = torch.flatten(y, start_dim=1) 943 if min_val.numel() == 0 or max_val.numel() == 0: 944 min_val, max_val = torch.aminmax(y, dim=1) 945 else: 946 min_val_cur, max_val_cur = torch.aminmax(y, dim=1) 947 min_val = min_val + self.averaging_constant * (min_val_cur - min_val) 948 max_val = max_val + self.averaging_constant * (max_val_cur - max_val) 949 self.min_val.resize_(min_val.shape) 950 self.max_val.resize_(max_val.shape) 951 self.min_val.copy_(min_val) 952 self.max_val.copy_(max_val) 953 return x_orig 954 955 956class HistogramObserver(UniformQuantizationObserverBase): 957 r""" 958 The module records the running histogram of tensor values along with 959 min/max values. ``calculate_qparams`` will calculate scale and zero_point. 960 961 Args: 962 bins: Number of bins to use for the histogram 963 dtype: dtype argument to the `quantize` node needed to implement the 964 reference model spec 965 qscheme: Quantization scheme to be used 966 reduce_range: Reduces the range of the quantized data type by 1 bit 967 eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 968 969 The scale and zero point are computed as follows: 970 971 1. Create the histogram of the incoming inputs. 972 The histogram is computed continuously, and the ranges per bin change 973 with every new tensor observed. 974 2. Search the distribution in the histogram for optimal min/max values. 975 The search for the min/max values ensures the minimization of the 976 quantization error with respect to the floating point model. 977 3. 
Compute the scale and zero point the same way as in the
       :class:`~torch.ao.quantization.MinMaxObserver`
    """
    histogram: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        bins: int = 2048,
        dtype: torch.dtype = torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps,
        is_dynamic=False,
        **kwargs,
    ) -> None:
        if not is_per_tensor(qscheme):
            raise NotImplementedError(
                "HistogramObserver's qscheme only supports torch.per_tensor_symmetric \
                    and torch.per_tensor_affine."
            )
        if is_dynamic:
            raise NotImplementedError(
                "HistogramObserver doesn't support dynamic quantization"
            )
        # bins: The number of bins used for histogram calculation.
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.bins = bins
        self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
        self.upsample_rate = (
            16  # used to reduce quantization errors when upscaling histogram
        )

    def _get_norm(
        self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Compute the norm of the values uniformly distributed between
        delta_begin and delta_end.
        Currently only L2 norm is supported.

        norm = density * (integral_{begin, end} x^2)
             = density * (end^3 - begin^3) / 3
        """
        norm = (
            delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin
        ) / 3
        return density * norm

    def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int):
        r"""
        Compute the quantization error if we use start_bin to end_bin as the
        min and max to do the quantization.
        """
        bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

        dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
        if dst_bin_width == 0.0:
            return 0.0

        src_bin = torch.arange(self.bins, device=self.histogram.device)
        # distances from the beginning of the first dst_bin to the beginning and
        # end of src_bin
        src_bin_begin = (src_bin - next_start_bin) * bin_width
        src_bin_end = src_bin_begin + bin_width

        # which dst_bins the beginning and end of src_bin belong to?
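        # i.e. floor(src_bin_begin / dst_bin_width) and floor(src_bin_end / dst_bin_width),
        # clamped to the valid destination-bin index range [0, dst_nbins - 1]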
1062 dst_bin_of_begin = torch.clamp( 1063 torch.div(src_bin_begin, dst_bin_width, rounding_mode="floor"), 1064 0, 1065 self.dst_nbins - 1, 1066 ) 1067 dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width 1068 1069 dst_bin_of_end = torch.clamp( 1070 torch.div(src_bin_end, dst_bin_width, rounding_mode="floor"), 1071 0, 1072 self.dst_nbins - 1, 1073 ) 1074 density = self.histogram / bin_width 1075 1076 norm = torch.zeros(self.bins, device=self.histogram.device) 1077 1078 delta_begin = src_bin_begin - dst_bin_of_begin_center 1079 delta_end = dst_bin_width / 2 1080 norm += self._get_norm( 1081 delta_begin, 1082 torch.ones(self.bins, device=self.histogram.device) * delta_end, 1083 density, 1084 ) 1085 1086 norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( 1087 torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density 1088 ) 1089 1090 dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 1091 1092 delta_begin = -dst_bin_width / 2 1093 delta_end = src_bin_end - dst_bin_of_end_center 1094 norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) 1095 1096 return norm.sum().item() 1097 1098 def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]: 1099 r"""Non-linear parameter search. 1100 1101 An approximation for L2 error minimization for selecting min/max. 1102 By selecting new min/max, we filter out outliers in input distribution. 1103 This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in 1104 caffe2/quantization/server/norm_minimization.cc 1105 """ 1106 assert self.histogram.size()[0] == self.bins, "bins mismatch" 1107 bin_width = (self.max_val - self.min_val) / self.bins 1108 1109 # cumulative sum 1110 total = torch.sum(self.histogram).item() 1111 cSum = torch.cumsum(self.histogram, dim=0) 1112 1113 stepsize = 1e-5 # granularity 1114 alpha = 0.0 # lower bound 1115 beta = 1.0 # upper bound 1116 start_bin = 0 1117 end_bin = self.bins - 1 1118 norm_min = float("inf") 1119 1120 while alpha < beta: 1121 # Find the next step 1122 next_alpha = alpha + stepsize 1123 next_beta = beta - stepsize 1124 1125 # find the left and right bins between the quantile bounds 1126 l = start_bin 1127 r = end_bin 1128 while l < end_bin and cSum[l] < next_alpha * total: 1129 l = l + 1 1130 while r > start_bin and cSum[r] > next_beta * total: 1131 r = r - 1 1132 1133 # decide the next move 1134 next_start_bin = start_bin 1135 next_end_bin = end_bin 1136 if (l - start_bin) > (end_bin - r): 1137 # move the start bin 1138 next_start_bin = l 1139 alpha = next_alpha 1140 else: 1141 # move the end bin 1142 next_end_bin = r 1143 beta = next_beta 1144 1145 if next_start_bin == start_bin and next_end_bin == end_bin: 1146 continue 1147 1148 # calculate the quantization error using next_start_bin and next_end_bin 1149 norm = self._compute_quantization_error(next_start_bin, next_end_bin) 1150 1151 if norm > norm_min: 1152 break 1153 norm_min = norm 1154 start_bin = next_start_bin 1155 end_bin = next_end_bin 1156 1157 new_min = self.min_val + bin_width * start_bin 1158 new_max = self.min_val + bin_width * (end_bin + 1) 1159 return new_min, new_max 1160 1161 def _upscale_histogram( 1162 self, 1163 histogram: torch.Tensor, 1164 orig_min: torch.Tensor, 1165 orig_max: torch.Tensor, 1166 update_min: torch.Tensor, 1167 update_max: torch.Tensor, 1168 ): 1169 # this turns the histogram into a more fine-coarsed histogram to reduce 1170 # bin quantization errors 1171 histogram = 
histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate
        bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate)
        mid_points_histogram = (
            torch.linspace(
                orig_min,
                orig_max,
                self.bins * self.upsample_rate + 1,
                device=orig_min.device,
            )[:-1].to(histogram.device)
            + 0.5 * bin_size
        )
        boundaries_new_histogram = torch.linspace(
            update_min, update_max, self.bins + 1, device=update_min.device
        ).to(histogram.device)
        # this maps the mid-points of the histogram to the new histogram's space
        bucket_assignments = (
            torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True)
            - 1
        )
        # this then maps the histogram mid-points in the new space, weighted by the original histogram's values
        # this is just the old histogram in the new histogram's space

        # in case numerical issues cause values to land above/below the maximum/minimum bin
        bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1
        bucket_assignments[bucket_assignments < 0] = 0

        update_histogram = torch.bincount(
            bucket_assignments, weights=histogram, minlength=self.bins
        )
        return update_histogram

    def _combine_histograms(
        self,
        orig_hist: torch.Tensor,
        orig_min: torch.Tensor,
        orig_max: torch.Tensor,
        update_hist: torch.Tensor,
        update_min: torch.Tensor,
        update_max: torch.Tensor,
    ) -> torch.Tensor:
        # If the new min and max are the same as the current min and max,
        # we can just add the new histogram to the original histogram
        if update_min == orig_min and update_max == orig_max:
            return orig_hist + update_hist

        # If the orig hist only has one value (i.e., the min and max are the same)
        # we can just add it into the new histogram
        if orig_min == orig_max:
            bin_value = torch.sum(update_hist)
            transformed_orig_hist = (
                torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max)  # type: ignore[arg-type]
                * bin_value
            )
            return transformed_orig_hist + update_hist

        # We assume the update_hist is already in the target range; we will map the orig_max to it
        assert update_min <= orig_min
        assert update_max >= orig_max

        # Now we need to turn the old histogram into the range of the new histogram
        transformed_orig_hist = self._upscale_histogram(
            orig_hist,
            orig_min,
            orig_max,
            update_min,
            update_max,
        )

        return update_hist + transformed_orig_hist

    def reset_histogram(
        self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor
    ) -> None:
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        assert (
            min_val.numel() == 1 and max_val.numel() == 1
        ), "histogram min/max values must be scalar."
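        # torch.histc buckets the elements of x into self.bins equal-width bins
        # spanning [min_val, max_val]; elements outside that range are ignored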
1251 new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type] 1252 self.histogram.detach_().resize_(new_histogram.shape) 1253 self.histogram.copy_(new_histogram) 1254 1255 def forward(self, x_orig: torch.Tensor) -> torch.Tensor: # pyre-ignore[14] 1256 if x_orig.numel() == 0: 1257 return x_orig 1258 x = x_orig.detach() 1259 x_min, x_max = torch.aminmax(x) 1260 # want to ignore torch.inf since we don't actually 1261 # want to make our quantization range infinite 1262 # and in practice those values will be clamped 1263 if x_min == -torch.inf or x_max == torch.inf: 1264 warnings.warn("torch.inf detected in input tensor, ignoring input") 1265 x = x[x.abs() != torch.inf] 1266 if x.numel() == 0: 1267 return x_orig 1268 x_min, x_max = torch.aminmax(x) 1269 1270 current_min = self.min_val 1271 current_max = self.max_val 1272 1273 is_uninitialized = self.min_val == float("inf") or self.max_val == float("-inf") 1274 if is_uninitialized: 1275 self.reset_histogram(x, x_min, x_max) 1276 else: 1277 update_min, update_max = x_min, x_max 1278 new_min = torch.min(current_min, update_min) 1279 new_max = torch.max(current_max, update_max) 1280 1281 # TODO: For some reason, this is required for it to pass torchscript test 1282 # new_min and new_max should already have requires_grad set to False 1283 new_min, new_max = new_min.detach(), new_max.detach() 1284 update_histogram = torch.histc( 1285 x, self.bins, min=new_min, max=new_max # type: ignore[arg-type] 1286 ).to(self.histogram.device) 1287 if new_min == current_min and new_max == current_max: 1288 combined_histogram = self.histogram + update_histogram 1289 self.histogram.detach_().resize_(combined_histogram.shape) 1290 self.histogram.copy_(combined_histogram) 1291 else: 1292 combined_histogram = self._combine_histograms( 1293 self.histogram, 1294 current_min, 1295 current_max, 1296 update_histogram, 1297 new_min, 1298 new_max, 1299 ) 1300 self.histogram.detach_().resize_(combined_histogram.shape) 1301 self.histogram.copy_(combined_histogram) 1302 self.min_val.detach_().resize_(new_min.shape) 1303 self.min_val.copy_(new_min) 1304 self.max_val.detach_().resize_(new_max.shape) 1305 self.max_val.copy_(new_max) 1306 1307 return x_orig 1308 1309 @torch.jit.export 1310 def calculate_qparams(self): 1311 is_uninitialized = self.min_val == float("inf") and self.max_val == float( 1312 "-inf" 1313 ) 1314 if is_uninitialized: 1315 warnings.warn( 1316 "must run observer before calling calculate_qparams.\ 1317 Returning default scale and zero point " 1318 ) 1319 return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor( 1320 [0], device=self.min_val.device.type 1321 ) 1322 assert self.bins == len(self.histogram), ( 1323 "The number of bins in histogram should be equal to the number of bins " 1324 "supplied while making this observer" 1325 ) 1326 1327 new_min, new_max = self._non_linear_param_search() 1328 1329 return self._calculate_qparams(new_min, new_max) 1330 1331 def _save_to_state_dict(self, destination, prefix, keep_vars): 1332 super()._save_to_state_dict(destination, prefix, keep_vars) 1333 destination[prefix + "min_val"] = self.min_val 1334 destination[prefix + "max_val"] = self.max_val 1335 1336 def _load_from_state_dict( 1337 self, 1338 state_dict, 1339 prefix, 1340 local_metadata, 1341 strict, 1342 missing_keys, 1343 unexpected_keys, 1344 error_msgs, 1345 ): 1346 version = local_metadata.get("version", None) 1347 1348 if version is None or version < 3: 1349 # if min_val and max_val are not initialized, 
update their shape 1350 # to account for the differences between v2 and v3 1351 min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" 1352 if min_val_name in state_dict: 1353 if state_dict[min_val_name].shape == torch.Size([0]): 1354 state_dict[min_val_name] = torch.tensor(float("inf")) 1355 if max_val_name in state_dict: 1356 if state_dict[max_val_name].shape == torch.Size([0]): 1357 state_dict[max_val_name] = torch.tensor(float("-inf")) 1358 1359 local_state = ["min_val", "max_val"] 1360 for name in local_state: 1361 key = prefix + name 1362 if key in state_dict: 1363 val = state_dict[key] 1364 setattr(self, name, val) 1365 elif strict: 1366 missing_keys.append(key) 1367 super()._load_from_state_dict( 1368 state_dict, 1369 prefix, 1370 local_metadata, 1371 strict, 1372 missing_keys, 1373 unexpected_keys, 1374 error_msgs, 1375 ) 1376 1377 def extra_repr(self): 1378 return f"min_val={self.min_val}, max_val={self.max_val}" 1379 1380 1381class FixedQParamsObserver(ObserverBase): 1382 r""" 1383 Observer that simulates quantize and dequantize with fixed 1384 quantization parameters in training time. Only per tensor 1385 quantization is supported. 1386 1387 Args: 1388 `scale` (float): fixed scale for the observer 1389 `zero_point` (int): fixed zero point for the observer 1390 `dtype`, `qscheme`, `quant_min`, `quant_max` 1391 """ 1392 1393 scale: torch.Tensor 1394 zero_point: torch.Tensor 1395 1396 def __init__( 1397 self, 1398 scale, 1399 zero_point, 1400 dtype=torch.quint8, 1401 qscheme=torch.per_tensor_affine, 1402 quant_min=0, 1403 quant_max=255, 1404 is_dynamic=False, 1405 **kwargs, 1406 ): 1407 if is_dynamic: 1408 raise NotImplementedError( 1409 "FixedQParamsObserver doesn't support dynamic quantization" 1410 ) 1411 super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) 1412 self.quant_min = quant_min 1413 self.quant_max = quant_max 1414 self.register_buffer("scale", torch.tensor([scale], dtype=torch.float)) 1415 self.register_buffer("zero_point", torch.tensor([zero_point], dtype=torch.int)) 1416 self.dtype = dtype 1417 self.qscheme = qscheme 1418 1419 def forward(self, X): 1420 return X 1421 1422 @torch.jit.export 1423 def calculate_qparams(self): 1424 return self.scale, self.zero_point 1425 1426 1427class PlaceholderObserver(ObserverBase): 1428 r""" 1429 Observer that doesn't do anything and just passes its configuration to the 1430 quantized module's ``.from_float()``. 1431 1432 Can be used for quantization to float16 which doesn't require determining 1433 ranges. 1434 1435 Args: 1436 dtype: dtype argument to the `quantize` node needed to implement the 1437 reference model spec. 1438 quant_min: minimum value in quantized domain (TODO: align behavior with other observers) 1439 quant_max: maximum value in quantized domain 1440 custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation 1441 (Can be used in Graph Mode Passes for special case ops). 1442 compute_dtype (deprecated): if set, marks the future quantize function to use 1443 dynamic quantization instead of static quantization. 1444 This field is deprecated, use `is_dynamic=True` instead. 1445 is_dynamic: if True, the `quantize` function in the reference model 1446 representation taking stats from this observer instance will 1447 use dynamic quantization. 
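
    A minimal usage sketch (illustrative only):

    Example::

        >>> # xdoctest: +SKIP("illustrative example")
        >>> obs_ctr = PlaceholderObserver.with_args(dtype=torch.float16)
        >>> obs = obs_ctr()
        >>> x = torch.randn(2, 2)
        >>> torch.equal(obs(x), x)  # forward is a pass-through
        True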
1448 """ 1449 1450 def __init__( 1451 self, 1452 dtype=torch.float32, 1453 custom_op_name="", 1454 compute_dtype=None, 1455 quant_min=None, 1456 quant_max=None, 1457 qscheme=None, 1458 eps=None, 1459 is_dynamic=False, 1460 ) -> None: 1461 super().__init__(dtype=dtype, is_dynamic=is_dynamic) 1462 if qscheme is None: 1463 qscheme = torch.per_tensor_affine 1464 if eps is None: 1465 eps = torch.finfo(torch.float32).eps 1466 1467 # dtype of input of the target operator, e.g. for dynamic quantization 1468 # ops, the dtype will be float32 1469 self.dtype = dtype 1470 self.qscheme = qscheme 1471 self.quant_min = quant_min 1472 self.quant_max = quant_max 1473 self.eps = eps 1474 self.custom_op = custom_op_name 1475 # used for configuration of computation type for dynamic quantization 1476 if compute_dtype: 1477 is_dynamic = True 1478 warnings.warn( 1479 "Please use `is_dynamic` instead of `compute_dtype`. \ 1480 `compute_dtype` will be deprecated in a future release \ 1481 of PyTorch." 1482 ) 1483 1484 def forward(self, x): 1485 return x 1486 1487 @torch.jit.export 1488 def extra_repr(self): 1489 return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" 1490 1491 @torch.jit.export 1492 def calculate_qparams(self): 1493 raise Exception( # noqa: TRY002 1494 "calculate_qparams should not be called for PlaceholderObserver" 1495 ) 1496 1497 1498class RecordingObserver(ObserverBase): 1499 r""" 1500 The module is mainly for debug and records the tensor values during runtime. 1501 1502 Args: 1503 dtype: Quantized data type 1504 qscheme: Quantization scheme to be used 1505 reduce_range: Reduces the range of the quantized data type by 1 bit 1506 """ 1507 __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]} 1508 1509 def __init__(self, dtype=torch.quint8): 1510 super().__init__(dtype=dtype, is_dynamic=False) # type: ignore[call-arg] 1511 self.tensor_val = [] 1512 1513 def forward(self, x): 1514 self.tensor_val.append(x.clone()) 1515 return x 1516 1517 @torch.jit.export 1518 def calculate_qparams(self): 1519 raise Exception( # noqa: TRY002 1520 "calculate_qparams should not be called for RecordingObserver" 1521 ) 1522 1523 @torch.jit.export 1524 def get_tensor_value(self): 1525 return self.tensor_val 1526 1527 1528class NoopObserver(ObserverBase): 1529 r""" 1530 Observer that doesn't do anything and just passes its configuration to the 1531 quantized module's ``.from_float()``. 1532 1533 Primarily used for quantization to float16 which doesn't require determining 1534 ranges. 1535 1536 Args: 1537 dtype: Quantized data type 1538 custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation 1539 (Can be used in Graph Mode Passes for special case ops). 1540 """ 1541 1542 def __init__(self, dtype=torch.float16, custom_op_name="") -> None: 1543 super().__init__(dtype=dtype, is_dynamic=False) 1544 self.dtype = dtype 1545 self.custom_op = custom_op_name 1546 1547 def forward(self, x): 1548 return x 1549 1550 @torch.jit.export 1551 def calculate_qparams(self): 1552 raise Exception( # noqa: TRY002 1553 "calculate_qparams should not be called for NoopObserver" 1554 ) 1555 1556 1557class ReuseInputObserver(ObserverBase): 1558 r"""This observer is used when we want to reuse the observer from the operator 1559 that produces the input Tensor, typically used for operators like reshape, e.g. 1560 ``` 1561 x0 = ... 
1562 x1 = x0.reshape() 1563 ``` 1564 if we configure x0 to be observed by some observer, let's say MinMaxObserver, 1565 and reshape is configured with ReuseInputObserver, we'll reuse the observer instance 1566 for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1. 1567 1568 Note: this is only enabled in FX Graph Mode Quantization 1569 """ 1570 1571 def __init__(self) -> None: 1572 super().__init__(torch.quint8, is_dynamic=False) 1573 1574 def forward(self, x): 1575 return x 1576 1577 @torch.jit.export 1578 def calculate_qparams(self): 1579 raise Exception( # noqa: TRY002 1580 "calculate_qparams should not be called for ReuseInputObserver" 1581 ) 1582 1583 1584def _is_observer_script_module(mod, obs_type_name): 1585 """Returns true if given mod is an instance of Observer script module.""" 1586 if isinstance(mod, torch.jit.RecursiveScriptModule): 1587 # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver' 1588 suffix = mod._c.qualified_name.split(".", 1)[1] 1589 name = re.sub(r"\.___torch_mangle_\d+", "", suffix) 1590 return obs_type_name in name 1591 return False 1592 1593 1594def _is_activation_post_process(module): 1595 return isinstance( 1596 module, 1597 (torch.ao.quantization.ObserverBase, torch.ao.quantization.FakeQuantizeBase), 1598 ) or _is_observer_script_module(module, "quantization.observer") 1599 1600 1601def _is_per_channel_script_obs_instance(module): 1602 if isinstance(module, torch.jit.RecursiveScriptModule): 1603 return _is_observer_script_module( 1604 module, "quantization.observer.PerChannelMinMaxObserver" 1605 ) or _is_observer_script_module( 1606 module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" 1607 ) 1608 return False 1609 1610 1611def get_observer_state_dict(mod): 1612 r""" 1613 Returns the state dict corresponding to the observer stats. 1614 Traverse the model state_dict and extract out the stats. 1615 """ 1616 od = OrderedDict() 1617 if isinstance(mod, torch.jit.RecursiveScriptModule): 1618 for k, v in mod.state_dict().items(): 1619 if "observer" in k: 1620 od[k] = v 1621 else: 1622 # path for GraphModule and nn.Module (eager mode) 1623 for k, v in mod.state_dict().items(): 1624 if "activation_post_process" in k: 1625 od[k] = v 1626 od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] 1627 return od 1628 1629 1630def load_observer_state_dict(mod, obs_dict): 1631 r""" 1632 Given input model and a state_dict containing model observer stats, 1633 load the stats back into the model. The observer state_dict can be saved 1634 using torch.ao.quantization.get_observer_state_dict 1635 """ 1636 missing_keys: List[str] = [] 1637 unexpected_keys: List[str] = [] 1638 for name, module in mod.named_modules(): 1639 prefix = name + "." 1640 if _is_activation_post_process(module): 1641 if _is_per_channel_script_obs_instance(module): 1642 # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. 
1643 # However this is not called when the module is scripted and we end up calling the default one in module.py 1644 module._load_from_state_dict_script( 1645 obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] 1646 ) 1647 else: 1648 module._load_from_state_dict( 1649 obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] 1650 ) 1651 for k in missing_keys: 1652 if "observer" in k or "activation_post_process" in k: 1653 raise Exception( # noqa: TRY002 1654 f"Missing keys for observer {k} in state_dict" 1655 ) 1656 for k in unexpected_keys: 1657 if "observer" in k or "activation_post_process" in k: 1658 raise Exception( # noqa: TRY002 1659 f"Unexpected keys for observer {k} in state_dict" 1660 ) 1661 1662 1663# Restrict activations to be in the range (0,127) 1664default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127) 1665""" 1666Default observer for static quantization, usually used for debugging. 1667""" 1668 1669default_placeholder_observer = PlaceholderObserver 1670""" 1671Default placeholder observer, usually used for quantization to torch.float16. 1672""" 1673 1674default_debug_observer = RecordingObserver 1675""" 1676Default debug-only observer. 1677""" 1678 1679default_weight_observer = MinMaxObserver.with_args( 1680 dtype=torch.qint8, qscheme=torch.per_tensor_symmetric 1681) 1682""" 1683Default weight observer. 1684""" 1685 1686weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args( 1687 dtype=torch.qint8, 1688 qscheme=torch.per_tensor_symmetric, 1689 quant_min=-127, 1690 quant_max=127, 1691 eps=2**-12, 1692) 1693""" 1694Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. 1695""" 1696 1697default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127) 1698""" 1699Default histogram observer, usually used for PTQ. 1700""" 1701 1702default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( 1703 dtype=torch.qint8, qscheme=torch.per_channel_symmetric 1704) 1705""" 1706Default per-channel weight observer, usually used on backends where per-channel 1707weight quantization is supported, such as `fbgemm`. 1708""" 1709 1710per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args( 1711 dtype=torch.qint8, 1712 qscheme=torch.per_channel_symmetric, 1713 quant_min=-127, 1714 quant_max=127, 1715 eps=2**-12, 1716) 1717""" 1718Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. 1719""" 1720 1721default_dynamic_quant_observer = PlaceholderObserver.with_args( 1722 dtype=torch.quint8, 1723 quant_min=0, 1724 quant_max=255, 1725 is_dynamic=True, 1726) 1727""" 1728Default observer for dynamic quantization. 1729""" 1730 1731default_float_qparams_observer = PerChannelMinMaxObserver.with_args( 1732 dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 1733) 1734""" 1735Default observer for a floating point zero-point. 1736""" 1737 1738default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( 1739 dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 1740) 1741""" 1742Default observer for a floating point zero-point and 4 bit activations. 
1743""" 1744 1745# TODO(future PR): remove these defaults and enforce activation functions 1746# to explicitly specify their output range 1747default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args( 1748 scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255 1749) 1750default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args( 1751 scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255 1752) 1753# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases 1754default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer 1755default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer 1756 1757""" 1758Default observers for fixed qparams operations. 1759""" 1760 1761default_reuse_input_observer = ReuseInputObserver 1762""" 1763Default observer for operators like reshape that reuses the observer of input to 1764the operator 1765""" 1766