# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from typing import Optional, Tuple

import torch
from torch._refs import _unsqueeze_multiple
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from torch.library import impl, Library


# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")

_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.int16, torch.int32]
_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]

_DTYPE_TO_QVALUE_BOUNDS = {
    k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES
}
_DTYPE_TO_QVALUE_BOUNDS.update(
    {k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES}
)


# Helper to check that the passed-in quant_min and quant_max are valid for the dtype
def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
        raise ValueError(f"Unsupported dtype: {dtype}")
    quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]

    assert quant_min >= quant_min_lower_bound, (
        "quant_min out of bound for dtype, "
        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
    )

    assert quant_max <= quant_max_upper_bound, (
        "quant_max out of bound for dtype, "
        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
    )


quantized_decomposed_lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values

    Args:
        input (torch.Tensor): original float32 or bfloat16 Tensor
        scale (float): quantization parameter for affine quantization
        zero_point (int): quantization parameter for affine quantization
        quant_min (int): minimum quantized value for output Tensor
        quant_max (int): maximum quantized value for output Tensor
        dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
    """
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)
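
# Illustrative sketch (not part of the op library): per-tensor affine quantization maps
# x -> clamp(round(x / scale) + zero_point, quant_min, quant_max). For example, with
# scale=0.1, zero_point=128, quant_min=0, quant_max=255 (torch.uint8), an input value
# of 1.0 maps to round(1.0 / 0.1) + 128 = 138. The values below are only examples:
#
#   x = torch.randn(4, 4)
#   xq = torch.ops.quantized_decomposed.quantize_per_tensor(
#       x, 0.1, 128, 0, 255, torch.uint8
#   )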


@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    return torch.empty_like(input, dtype=dtype)


quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd"
)
def quantize_per_tensor_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values
    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensors instead of
    scalar values
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return quantize_per_tensor(
        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype
    )


@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    return torch.empty_like(input, dtype=dtype)


# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd"
)
def quantize_per_tensor_tensor2(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values
    Same as `quantize_per_tensor.tensor` but quant_min and quant_max are Scalar Tensors
    as well
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return quantize_per_tensor(
        input,
        scale.item(),
        zero_point.item(),
        quant_min.item(),
        quant_max.item(),
        dtype,
    )


@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
def quantize_per_tensor_tensor2_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    return quantize_per_tensor_tensor_meta(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
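
# Illustrative sketch (not part of the op library): the `.tensor` overload takes scale
# and zero_point as scalar Tensors, which is convenient in traced/exported graphs where
# the quantization parameters are produced by other ops. Values below are arbitrary
# examples:
#
#   xq = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
#       x, torch.tensor(0.1), torch.tensor(128), 0, 255, torch.uint8
#   )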


# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values

    Args:
        input (torch.Tensor): Tensor with dtype matching `dtype` argument,
        e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with
        quantization parameters in the arguments of this function (scale/zero_point)

        scale (float): quantization parameter for affine quantization

        zero_point (int): quantization parameter for affine quantization

        quant_min (int): minimum quantized value for input Tensor (not used in computation,
        reserved for pattern matching)

        quant_max (int): maximum quantized value for input Tensor (not used in computation,
        reserved for pattern matching)

        dtype (torch.dtype): dtype for input Tensor (not used in computation,
        reserved for pattern matching)

        out_dtype (torch.dtype?): optional dtype for output Tensor

    Returns:
        dequantized Tensor with dtype `out_dtype` (default: torch.float32)
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale
        # failed the test
        return (input.to(out_dtype) - zero_point) * scale
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")


@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    if out_dtype is None:
        out_dtype = torch.float32
    return torch.empty_like(input, dtype=out_dtype)
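
# Illustrative sketch (not part of the op library): dequantization inverts the affine
# mapping as (q - zero_point) * scale. Continuing the example above, q = 138 with
# scale=0.1 and zero_point=128 dequantizes to (138 - 128) * 0.1 = 1.0:
#
#   x_dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
#       xq, 0.1, 128, 0, 255, torch.uint8
#   )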


quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_tensor.tensor",
    "CompositeExplicitAutograd",
)
def dequantize_per_tensor_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values
    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensors instead of
    scalar values
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return dequantize_per_tensor(
        input,
        scale.item(),
        zero_point.item(),
        quant_min,
        quant_max,
        dtype,
        out_dtype=out_dtype,
    )


@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    if out_dtype is None:
        out_dtype = torch.float32
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
        return torch.empty_like(input, dtype=out_dtype)
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")


# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
    "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_tensor.tensor2",
    "CompositeExplicitAutograd",
)
def dequantize_per_tensor_tensor2(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values
    Same as `dequantize_per_tensor.tensor` but quant_min and quant_max are Scalar Tensors
    as well
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return dequantize_per_tensor(
        input,
        scale.item(),
        zero_point.item(),
        quant_min.item(),
        quant_max.item(),
        dtype,
        out_dtype=out_dtype,
    )


@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
def dequantize_per_tensor_tensor2_meta(
    input,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    return dequantize_per_tensor_tensor_meta(
        input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype
    )


quantized_decomposed_lib.define(
    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Given an input Tensor, derive the per tensor affine quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
        input (torch.Tensor): floating point input Tensor
        quant_min (int): minimum quantized value for target quantized Tensor
        quant_max (int): maximum quantized value for target quantized Tensor
        dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
        scale (float): quantization parameter for the target quantized Tensor
        zero_point (int): quantization parameter for the target quantized Tensor
    """
    assert input.dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        dtype in _DTYPE_TO_QVALUE_BOUNDS
    ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)

    return determine_qparams(
        min_val,
        max_val,
        qmin,
        qmax,
        dtype,
        torch.Tensor([eps]),
        has_customized_qrange=False,
    )


quantized_decomposed_lib.define(
    "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_symmetric.tensor",
    "CompositeExplicitAutograd",
)
def choose_qparams_symmetric_tensor(
    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Given an input Tensor, derive the per tensor symmetric quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
        input (torch.Tensor): floating point input Tensor
        quant_min (int): minimum quantized value for target quantized Tensor
        quant_max (int): maximum quantized value for target quantized Tensor
        dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
        scale (float): quantization parameter for the target quantized Tensor
        zero_point (int): quantization parameter for the target quantized Tensor
    """
    assert input.dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        dtype in _DTYPE_TO_QVALUE_BOUNDS
    ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)
    return determine_qparams(
        min_val,
        max_val,
        qmin,
        qmax,
        dtype,
        torch.Tensor([eps]),
        has_customized_qrange=False,
        qscheme=torch.per_tensor_symmetric,
    )


@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
def choose_qparams_tensor_meta(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert input.dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        quant_min < quant_max
    ), f"Expecting quant_min to be smaller than quant_max but received min: \
        {quant_min} max: {quant_max}"
    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
        1, dtype=torch.int64, device=input.device
    )


@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta")
def choose_qparams_symmetric_tensor_meta(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
        1, dtype=torch.int64, device=input.device
    )
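
# Illustrative sketch (not part of the op library): choose_qparams derives scale and
# zero_point from the observed min/max of a Tensor, which can then be passed to the
# quantize ops. `eps` guards against a zero scale; 2**-12 below is only an example value:
#
#   scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
#       x, 0, 255, 2**-12, torch.uint8
#   )
#   xq = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
#       x, scale, zero_point, 0, 255, torch.uint8
#   )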


# Helper function used to implement per-channel quantization against any axis
def _permute_to_axis_zero(x, axis):
    new_axis_list = list(range(x.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = x.permute(tuple(new_axis_list))
    return y, new_axis_list
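
# Illustrative note (not part of the op library): _permute_to_axis_zero swaps `axis`
# with dimension 0, e.g. a (2, 3, 4) Tensor with axis=2 becomes (4, 3, 2). Because the
# swap is its own inverse, permuting with the returned list again restores the original
# layout, which is how the per-channel ops undo it after quantizing.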


quantized_decomposed_lib.define(
    "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
def quantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values

    Args:
        input (torch.Tensor): original float32 or bfloat16 Tensor
        scales (torch.Tensor): a list of scale quantization parameters for
        affine quantization, one per channel
        zero_points (torch.Tensor): a list of zero_point quantization parameters for
        affine quantization, one per channel
        axis (int): channel axis of the input Tensor
        quant_min (int): minimum quantized value for output Tensor
        quant_max (int): maximum quantized value for output Tensor
        dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
    """
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)

    new_shape = [1] * input.dim()
    new_shape[0] = scales.shape[0]
    scales = scales.view(new_shape)
    zero_points = zero_points.view(new_shape)

    res = torch.clamp(
        torch.round(input * (1.0 / scales)) + zero_points, quant_min, quant_max
    )
    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)


@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
def quantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)


# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values

    Args:
        input (torch.Tensor): Tensor with dtype matching `dtype` argument,
        e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with
        quantization parameters in the arguments of this function (scales/zero_points/axis)

        scales (torch.Tensor): a list of scale quantization parameters for
        affine quantization, one per channel

        zero_points (torch.Tensor): a list of zero_point quantization parameters for
        affine quantization, one per channel

        axis (int): channel axis of the input Tensor

        quant_min (int): minimum quantized value for input Tensor (not used in computation,
        reserved for pattern matching)

        quant_max (int): maximum quantized value for input Tensor (not used in computation,
        reserved for pattern matching)

        dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
        reserved for pattern matching)

        out_dtype (torch.dtype?): optional dtype for output Tensor

    Returns:
        dequantized Tensor with dtype `out_dtype` (default: torch.float32)
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)

    new_shape = [1] * input.dim()
    new_shape[0] = scales.shape[0]
    scales = scales.view(new_shape)
    if zero_points is not None:
        res = (input - zero_points.view(new_shape)) * scales
    else:
        res = input * scales

    res = res.to(out_dtype)

    out = res.permute(tuple(permute_axis_list))
    return out


@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
def dequantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=out_dtype)
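
# Illustrative sketch (not part of the op library): per-channel quantization uses one
# (scale, zero_point) pair per slice along `axis`. For a weight of shape (out_ch, in_ch)
# quantized along axis=0, scales and zero_points both have out_ch elements. The values
# below are arbitrary examples:
#
#   w = torch.randn(8, 16)
#   scales = w.abs().amax(dim=1) / 127.0
#   zero_points = torch.zeros(8, dtype=torch.int64)
#   wq = torch.ops.quantized_decomposed.quantize_per_channel(
#       w, scales, zero_points, 0, -128, 127, torch.int8
#   )
#   w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
#       wq, scales, zero_points, 0, -128, 127, torch.int8
#   )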


quantized_decomposed_lib.define(
    "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token",
    "CompositeExplicitAutograd",
)
def choose_qparams_per_token(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose quantization parameters for per token quantization. This means for an N dimension Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
    every N elements with the same quantization parameter. The dimension for scales/zero_points
    will be (M1 * M2 ... * Mn)

    Args:
        input (torch.Tensor): original float32/float16 Tensor
        dtype (torch.dtype): dtype (e.g. torch.uint8) for the target quantized Tensor

    Returns:
        scales and zero_points, both float32 Tensors
    """

    scales = input.abs().amax(dim=-1, keepdim=True)
    if scales.dtype == torch.float16:
        scales = (
            scales.float()
        )  # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
    if dtype == torch.int8:
        n_bits = 8
        quant_max = 2 ** (n_bits - 1) - 1
    else:
        raise Exception(  # noqa: TRY002
            f"unsupported dtype in choose_qparams_per_token: {dtype}"
        )

    scales = scales.clamp(min=1e-5).div(quant_max)
    zero_points = torch.zeros_like(scales)
    return scales, zero_points


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token",
    "Meta",
)
def choose_qparams_per_token_meta(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    size = (1, input.size(-1))
    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
        size, dtype=torch.int64, device=input.device
    )
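
# Illustrative sketch (not part of the op library): the per-token qparams above are
# symmetric, scale = clamp(absmax, 1e-5) / 127 for int8 with zero_point = 0, computed
# per row of the last dimension. E.g. for activations `act` of shape (batch, seq, hidden),
# each of the batch * seq tokens gets its own scale:
#
#   scales, zero_points = torch.ops.quantized_decomposed.choose_qparams_per_token(
#       act, torch.int8
#   )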


quantized_decomposed_lib.define(
    "_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "_choose_qparams_per_token_asymmetric_impl",
    "CompositeImplicitAutograd",
)
def _choose_qparams_per_token_asymmetric_impl(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose quantization parameters for per token quantization. This means for an N dimension Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
    every N elements with the same quantization parameter. The dimension for scales/zero_points
    will be (M1 * M2 ... * Mn)

    Args:
        input (torch.Tensor): original float32/float16 Tensor
        dtype (torch.dtype): dtype (e.g. torch.uint8) for the target quantized Tensor

    Returns:
        scales and zero_points, both float32 Tensors
    """
    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
    qmin, qmax = -128, 127
    min_val = torch.amin(input, dim=-1, keepdim=True)
    max_val = torch.amax(input, dim=-1, keepdim=True)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?

    # scale
    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
    scale = scale.clamp(min=eps)

    # zero point
    descaled_min = min_val_neg / scale
    descaled_max = max_val_pos / scale
    zero_point_from_min_error = qmin + descaled_min
    zero_point_from_max_error = qmax + descaled_max
    zero_point = torch.where(
        zero_point_from_min_error + zero_point_from_max_error > 0,
        qmin - descaled_min,
        qmax - descaled_max,
    )
    zero_point = torch.clamp(zero_point, qmin, qmax).round()

    return scale.to(torch.float32), zero_point.to(torch.float32)


quantized_decomposed_lib.define(
    "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "CompositeExplicitAutograd",
)
def choose_qparams_per_token_asymmetric(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return _choose_qparams_per_token_asymmetric_impl(input, dtype)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "Meta",
)
def choose_qparams_per_token_asymmetric_meta(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    size = (1, input.size(-1))
    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
        size, dtype=torch.int64, device=input.device
    )
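
# Illustrative worked example (not part of the op library), for one token with
# min_val = -1.0 and max_val = 3.0: scale = (3.0 - (-1.0)) / (127 - (-128)) ~ 0.0157.
# descaled_min ~ -63.75 and descaled_max ~ 191.25, so the min/max errors sum to
# ~ 126.5 > 0 and zero_point = qmin - descaled_min ~ -64.25, which rounds to -64.
# Dequantizing qmin then gives (-128 - (-64)) * 0.0157 ~ -1.0, i.e. the original min.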


def _per_token_quant_qparam_dim_check(input, scales, zero_points):
    num_tokens = math.prod(list(input.size())[:-1])
    assert (
        num_tokens == scales.numel()
    ), f"num_tokens: {num_tokens} scales: {scales.size()}"
    assert (
        num_tokens == zero_points.numel()
    ), f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"


quantized_decomposed_lib.define(
    "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd")
def quantize_per_token(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
):
    """Per token quantization for the Tensor using the quantization parameters to map
    from floating point to quantized values. This means for an N dimension Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
    every N elements with the same quantization parameter. The dimension for scales/zero_points
    will be (M1 * M2 ... * Mn)

    Args:
        input (torch.Tensor): original float32 or bfloat16 Tensor
        scales (float32 torch.Tensor): quantization parameter for per token affine quantization
        zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
        quant_min (int): minimum quantized value for output Tensor
        quant_max (int): maximum quantized value for output Tensor
        dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
    """
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    _per_token_quant_qparam_dim_check(input, scales, zero_points)
    input = (
        input.mul(1.0 / scales)
        .add(zero_points)
        .round()
        .clamp(quant_min, quant_max)
        .to(dtype)
    )
    return input


@impl(quantized_decomposed_lib, "quantize_per_token", "Meta")
def quantize_per_token_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
):
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)


quantized_decomposed_lib.define(
    "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
    "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd")
def dequantize_per_token(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    output_dtype: torch.dtype = torch.float32,
):
    """Per token dequantization for the Tensor using the quantization parameters to map
    from quantized values to floating point values. This means for an N dimension Tensor
    (M1, M2, ...Mn, N), we use the scales/zero_points computed for each N elements and
    dequantize every N elements with the same quantization parameter. The dimension for
    scales/zero_points will be (M1 * M2 ... * Mn)

    Args:
        input (torch.Tensor): quantized Tensor (uint8, int8 etc.)
        scales (float32 torch.Tensor): quantization parameter for per token affine quantization
        zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
        quant_min (int): minimum quantized value for input Tensor
        quant_max (int): maximum quantized value for input Tensor
        dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
        output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor

    Returns:
        dequantized Tensor with dtype `output_dtype`
    """
    input = input - zero_points
    input = input.to(output_dtype) * scales
    return input


@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta")
def dequantize_per_token_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    output_dtype: torch.dtype = torch.float32,
):
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    # TODO: support fp16
    return torch.empty_like(input, dtype=output_dtype)
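
# Illustrative sketch (not part of the op library): a per-token round trip using the
# asymmetric qparams from above; -128/127 and torch.int8 are example bounds and `act`
# is an example activation Tensor:
#
#   s, zp = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
#       act, torch.int8
#   )
#   act_q = torch.ops.quantized_decomposed.quantize_per_token(
#       act, s, zp, -128, 127, torch.int8
#   )
#   act_dq = torch.ops.quantized_decomposed.dequantize_per_token(
#       act_q, s, zp, -128, 127, torch.int8, torch.float32
#   )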


quantized_decomposed_lib.define(
    "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
    "int quant_max, ScalarType dtype, int group_size) -> Tensor"
)


# TODO: dtype is ignored for now
@impl(
    quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd"
)
def quantize_per_channel_group(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size=128,
):
    assert group_size > 1
    # needed for GPTQ single column quantize
    if group_size > input.shape[-1] and scales.shape[-1] == 1:
        group_size = input.shape[-1]

    assert input.shape[-1] % group_size == 0
    assert input.dim() == 2

    # TODO: check for dtype, currently we can't express torch.int4 so it's omitted
    to_quant = input.reshape(-1, group_size)
    assert torch.isnan(to_quant).sum() == 0

    scales = scales.reshape(-1, 1)
    zero_points = zero_points.reshape(-1, 1)

    input_int8 = (
        to_quant.mul(1.0 / scales)
        .add(zero_points)
        .round()
        .clamp_(quant_min, quant_max)
        .to(dtype)
        .reshape_as(input)
    )

    return input_int8


@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta")
def quantize_per_channel_group_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size=128,
):
    """Groupwise quantization within each channel for a 2-d Tensor using the quantization parameters
    to map from floating point to quantized values. This means for each row of a 2-d Tensor
    (M, N), we calculate scales/zero_points for each `group_size` elements
    and quantize every `group_size` elements with the same quantization parameter.
    The dimension for scales/zero_points will be (M * ceil(N / group_size),)

    Args:
        input (torch.Tensor): original float32 or bfloat16 Tensor
        scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
        zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
        quant_min (int): minimum quantized value for output Tensor
        quant_max (int): maximum quantized value for output Tensor
        dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
    """
    assert group_size > 1
    # needed for GPTQ single column quantize
    if group_size > input.shape[-1] and scales.shape[-1] == 1:
        group_size = input.shape[-1]

    assert input.shape[-1] % group_size == 0
    assert input.dim() == 2
    return torch.empty_like(input, dtype=dtype)
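
# Illustrative note (not part of the op library): for a (M, N) input with group_size g,
# the rows are flattened to (M * N / g, g) groups and scales/zero_points are reshaped to
# (M * N / g, 1), so e.g. a (4, 256) weight with group_size=128 uses 8 (scale, zero_point)
# pairs, two per row.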


quantized_decomposed_lib.define(
    "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, "
    "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_channel_group",
    "CompositeExplicitAutograd",
)
def dequantize_per_channel_group(
    w_int8: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size: int = 128,
    output_dtype: torch.dtype = torch.float32,
):
    """Groupwise dequantization within each channel for a 2-d Tensor using the quantization parameters
    to map from quantized values to floating point values. This means for each row of a 2-d Tensor
    (M, N), we use the scales/zero_points computed for each `group_size` elements
    and dequantize every `group_size` elements with the same quantization parameter.
    The dimension for scales/zero_points will be (M * ceil(N / group_size),)

    Args:
        input (torch.Tensor): quantized Tensor (uint8/int8 etc.)
        scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
        zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
        quant_min (int): minimum quantized value for input Tensor
        quant_max (int): maximum quantized value for input Tensor
        dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
        output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor

    Returns:
        dequantized Tensor with dtype `output_dtype`
    """

    assert group_size > 1
    # needed for GPTQ single column dequantize
    if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
        group_size = w_int8.shape[-1]
    assert w_int8.shape[-1] % group_size == 0
    assert w_int8.dim() == 2

    w_int8_grouped = w_int8.reshape(-1, group_size)
    scales = scales.reshape(-1, 1)
    if zero_points is not None:
        zp = zero_points.reshape(-1, 1)
    else:
        zp = torch.zeros([], dtype=torch.int32, device=scales.device)
    w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype)
    return w_dq


quantized_decomposed_lib.define(
    "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max) -> Tensor"
)


class FakeQuantPerChannel(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
        if scales.dtype != torch.float32:
            scales = scales.to(torch.float32)
        if zero_points.dtype != torch.int32:
            zero_points = zero_points.to(torch.int32)
        assert (
            input.dtype == torch.float32
        ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
        assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
        broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim))
        unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
        unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
        temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points
        out = (
            torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points
        ) * unsqueeze_scales
        mask = torch.logical_and((temp >= quant_min), (temp <= quant_max))

        ctx.save_for_backward(mask)
        return out

    @staticmethod
    def backward(ctx, gy):
        (mask,) = ctx.saved_tensors
        return gy * mask, None, None, None, None, None


@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd")
def fake_quant_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
    return FakeQuantPerChannel.apply(
        input, scales, zero_points, axis, quant_min, quant_max
    )


@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta")
def fake_quant_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
    return torch.empty_like(input)
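
# Illustrative note (not part of the op library): fake_quant_per_channel quantizes and
# immediately dequantizes in one differentiable op, using a straight-through estimator
# in backward (gradients pass only where the rounded value stayed inside
# [quant_min, quant_max]). A hedged usage sketch, with scales/zero_points as in the
# per-channel example above and -128/127 as example bounds:
#
#   w = torch.randn(8, 16, requires_grad=True)
#   w_fq = torch.ops.quantized_decomposed.fake_quant_per_channel(
#       w, scales, zero_points, 0, -128, 127
#   )
#   w_fq.sum().backward()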