# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from typing import Optional, Tuple

import torch
from torch._refs import _unsqueeze_multiple
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from torch.library import impl, Library


# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")

_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.int16, torch.int32]
_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]

_DTYPE_TO_QVALUE_BOUNDS = {
    k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES
}
_DTYPE_TO_QVALUE_BOUNDS.update(
    {k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES}
)

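# For illustration, the resulting bounds include (values as reported by
# torch.iinfo; listed here only as a reader aid):
#
#   _DTYPE_TO_QVALUE_BOUNDS[torch.int8]  == (-128, 127)
#   _DTYPE_TO_QVALUE_BOUNDS[torch.uint8] == (0, 255)
#   _DTYPE_TO_QVALUE_BOUNDS[torch.int16] == (-32768, 32767)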

# Helper to check that the passed-in quant_min and quant_max are valid for the dtype
def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
        raise ValueError(f"Unsupported dtype: {dtype}")
    quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]

    assert quant_min >= quant_min_lower_bound, (
        "quant_min out of bound for dtype, "
        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
    )

    assert quant_max <= quant_max_upper_bound, (
        "quant_max out of bound for dtype, "
        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
    )


quantized_decomposed_lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values

    Args:
       input (torch.Tensor): original float32, float16, or bfloat16 Tensor
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)

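# Illustrative usage sketch (hypothetical values, not taken from any test in this
# file): with scale=0.1 and zero_point=0 the op computes clamp(round(x / scale) + zp):
#
#   x = torch.tensor([-1.0, 0.0, 0.34])
#   q = quantize_per_tensor(x, 0.1, 0, -128, 127, torch.int8)
#   # round(x / 0.1) + 0 -> [-10, 0, 3], already inside [-128, 127],
#   # so q is tensor([-10, 0, 3], dtype=torch.int8)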

@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    return torch.empty_like(input, dtype=dtype)


quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd"
)
def quantize_per_tensor_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values
    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
    scalar values
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return quantize_per_tensor(
        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype
    )


@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    return torch.empty_like(input, dtype=dtype)


# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd"
)
def quantize_per_tensor_tensor2(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values
    Same as `quantize_per_tensor` but scale, zero_point, quant_min and quant_max are
    scalar Tensors instead of scalar values
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return quantize_per_tensor(
        input,
        scale.item(),
        zero_point.item(),
        quant_min.item(),
        quant_max.item(),
        dtype,
    )


@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
def quantize_per_tensor_tensor2_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    return quantize_per_tensor_tensor_meta(
        input, scale, zero_point, quant_min, quant_max, dtype
    )


# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values

    Args:
       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
       e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with
       quantization parameters in the argument of this function (scale/zero_point)

       scale (float): quantization parameter for affine quantization

       zero_point (int): quantization parameter for affine quantization

       quant_min (int): minimum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)

       quant_max (int): maximum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)

       dtype (torch.dtype): dtype for input Tensor (not used in computation,
       reserved for pattern matching)

       out_dtype (torch.dtype?): optional dtype for output Tensor

    Returns:
       dequantized float32 Tensor
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale
        # failed the test
        return (input.to(out_dtype) - zero_point) * scale
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")

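# Illustrative usage sketch (hypothetical values): dequantization is the inverse
# affine map (q - zero_point) * scale, computed in `out_dtype` (float32 by default):
#
#   q = torch.tensor([-10, 0, 3], dtype=torch.int8)
#   x_hat = dequantize_per_tensor(q, 0.1, 0, -128, 127, torch.int8)
#   # (q - 0) * 0.1 -> approximately tensor([-1.0, 0.0, 0.3]) as float32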

@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    if out_dtype is None:
        out_dtype = torch.float32
    return torch.empty_like(input, dtype=out_dtype)


quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_tensor.tensor",
    "CompositeExplicitAutograd",
)
def dequantize_per_tensor_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values
    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
    scalar values
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return dequantize_per_tensor(
        input,
        scale.item(),
        zero_point.item(),
        quant_min,
        quant_max,
        dtype,
        out_dtype=out_dtype,
    )


@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    if out_dtype is None:
        out_dtype = torch.float32
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
        return torch.empty_like(input, dtype=out_dtype)
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")


# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
    "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_tensor.tensor2",
    "CompositeExplicitAutograd",
)
def dequantize_per_tensor_tensor2(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values
    Same as `dequantize_per_tensor` but scale, zero_point, quant_min and quant_max are
    scalar Tensors instead of scalar values
    """
    assert (
        zero_point.numel() == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert (
        scale.numel() == 1
    ), f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return dequantize_per_tensor(
        input,
        scale.item(),
        zero_point.item(),
        quant_min.item(),
        quant_max.item(),
        dtype,
        out_dtype=out_dtype,
    )


@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
def dequantize_per_tensor_tensor2_meta(
    input,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    return dequantize_per_tensor_tensor_meta(
        input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype
    )


quantized_decomposed_lib.define(
    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Given an input Tensor, derive the per tensor affine quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
       input (torch.Tensor): floating point input Tensor
       quant_min (int): minimum quantized value for target quantized Tensor
       quant_max (int): maximum quantized value for target quantized Tensor
       dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
       scale (float): quantization parameter for the target quantized Tensor
       zero_point (int): quantization parameter for the target quantized Tensor
    """
    assert input.dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        dtype in _DTYPE_TO_QVALUE_BOUNDS
    ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)

    return determine_qparams(
        min_val,
        max_val,
        qmin,
        qmax,
        dtype,
        torch.Tensor([eps]),
        has_customized_qrange=False,
    )

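# Illustrative usage sketch (hypothetical values): pick per-tensor qparams for int8
# from the observed min/max of the input:
#
#   x = torch.randn(3, 4)
#   scale, zero_point = choose_qparams_tensor(
#       x, -128, 127, torch.finfo(torch.float32).eps, torch.int8
#   )
#   # scale is a 1-element float64 tensor, zero_point a 1-element int64 tensor,
#   # chosen so the observed range of x (extended to include 0) maps onto [-128, 127]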

quantized_decomposed_lib.define(
    "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_symmetric.tensor",
    "CompositeExplicitAutograd",
)
def choose_qparams_symmetric_tensor(
    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Given an input Tensor, derive the per tensor symmetric quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
       input (torch.Tensor): floating point input Tensor
       quant_min (int): minimum quantized value for target quantized Tensor
       quant_max (int): maximum quantized value for target quantized Tensor
       dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
       scale (float): quantization parameter for the target quantized Tensor
       zero_point (int): quantization parameter for the target quantized Tensor
    """
    assert input.dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        dtype in _DTYPE_TO_QVALUE_BOUNDS
    ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)
    return determine_qparams(
        min_val,
        max_val,
        qmin,
        qmax,
        dtype,
        torch.Tensor([eps]),
        has_customized_qrange=False,
        qscheme=torch.per_tensor_symmetric,
    )


@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
def choose_qparams_tensor_meta(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert input.dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        quant_min < quant_max
    ), f"Expecting quant_min to be smaller than quant_max but received min: \
        {quant_min} max: {quant_max}"
    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
        1, dtype=torch.int64, device=input.device
    )


@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta")
def choose_qparams_symmetric_tensor_meta(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
        1, dtype=torch.int64, device=input.device
    )


# Helper function used to implement per-channel quantization against any axis
def _permute_to_axis_zero(x, axis):
    new_axis_list = list(range(x.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = x.permute(tuple(new_axis_list))
    return y, new_axis_list


quantized_decomposed_lib.define(
    "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
def quantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values

    Args:
       input (torch.Tensor): original float32, float16, or bfloat16 Tensor
       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel
       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)

    new_shape = [1] * input.dim()
    new_shape[0] = scales.shape[0]
    scales = scales.view(new_shape)
    zero_points = zero_points.view(new_shape)

    res = torch.clamp(
        torch.round(input * (1.0 / scales)) + zero_points, quant_min, quant_max
    )
    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)

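# Illustrative usage sketch (hypothetical values): a (2, 3) weight quantized per
# output channel (axis=0), one (scale, zero_point) pair per row:
#
#   w = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0]])
#   scales = torch.tensor([1.0 / 127, 2.0 / 127])
#   zero_points = torch.zeros(2, dtype=torch.int64)
#   q = quantize_per_channel(w, scales, zero_points, 0, -128, 127, torch.int8)
#   # each row is divided by its own scale, rounded and clamped, so both rows
#   # become [-127, 0, 127] in int8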

@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
def quantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)


# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values

    Args:
       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
       e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with
       quantization parameters in the arguments of this function (scales/zero_points/axis)

       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel

       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel

       quant_min (int): minimum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)

       quant_max (int): maximum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)

       dtype (torch.dtype): dtype for input Tensor (not used in computation,
       reserved for pattern matching)

       out_dtype (torch.dtype?): optional dtype for output Tensor

    Returns:
       dequantized float32 Tensor
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)

    new_shape = [1] * input.dim()
    new_shape[0] = scales.shape[0]
    scales = scales.view(new_shape)
    if zero_points is not None:
        res = (input - zero_points.view(new_shape)) * scales
    else:
        res = input * scales

    res = res.to(out_dtype)

    out = res.permute(tuple(permute_axis_list))
    return out

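# Illustrative usage sketch (hypothetical values, continuing the per-channel
# example above): (q - zero_point) * scale is applied row by row along `axis`:
#
#   w_hat = dequantize_per_channel(q, scales, zero_points, 0, -128, 127, torch.int8)
#   # w_hat approximately recovers w as a float32 tensor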

@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
def dequantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=out_dtype)


quantized_decomposed_lib.define(
    "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token",
    "CompositeExplicitAutograd",
)
def choose_qparams_per_token(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose quantization parameters for per token quantization. This means for an N-dimensional Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
    every N elements with the same quantization parameter. The dimension for scales/zero_points
    will be (M1 * M2 ... * Mn)

    Args:
       input (torch.Tensor): original float32/float16 Tensor
       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor

    Returns:
        scales and zero_points, both float32 Tensors
    """

    scales = input.abs().amax(dim=-1, keepdim=True)
    if scales.dtype == torch.float16:
        scales = (
            scales.float()
        )  # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
    if dtype == torch.int8:
        n_bits = 8
        quant_max = 2 ** (n_bits - 1) - 1
    else:
        raise Exception(  # noqa: TRY002
            f"unsupported dtype in choose_qparams_per_token: {dtype}"
        )

    scales = scales.clamp(min=1e-5).div(quant_max)
    zero_points = torch.zeros_like(scales)
    return scales, zero_points

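# Illustrative usage sketch (hypothetical shapes): per-token symmetric qparams
# for a batch of 2 x 5 "tokens" with 16 elements each:
#
#   x = torch.randn(2, 5, 16)
#   scales, zero_points = choose_qparams_per_token(x, torch.int8)
#   # scales == x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127 and
#   # zero_points are zeros; both have shape (2, 5, 1), one pair per token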

@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token",
    "Meta",
)
def choose_qparams_per_token_meta(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    size = (1, input.size(-1))
    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
        size, dtype=torch.int64, device=input.device
    )


quantized_decomposed_lib.define(
    "_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "_choose_qparams_per_token_asymmetric_impl",
    "CompositeImplicitAutograd",
)
def _choose_qparams_per_token_asymmetric_impl(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose quantization parameters for per token quantization. This means for an N-dimensional Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
    every N elements with the same quantization parameter. The dimension for scales/zero_points
    will be (M1 * M2 ... * Mn)

    Args:
       input (torch.Tensor): original float32/float16 Tensor
       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor

    Returns:
        scales and zero_points, both float32 Tensors
    """
    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
    qmin, qmax = -128, 127
    min_val = torch.amin(input, dim=-1, keepdim=True)
    max_val = torch.amax(input, dim=-1, keepdim=True)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?

    # scale
    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
    scale = scale.clamp(min=eps)

    # zero point
    descaled_min = min_val_neg / scale
    descaled_max = max_val_pos / scale
    zero_point_from_min_error = qmin + descaled_min
    zero_point_from_max_error = qmax + descaled_max
    zero_point = torch.where(
        zero_point_from_min_error + zero_point_from_max_error > 0,
        qmin - descaled_min,
        qmax - descaled_max,
    )
    zero_point = torch.clamp(zero_point, qmin, qmax).round()

    return scale.to(torch.float32), zero_point.to(torch.float32)

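# Illustrative single-token example of the asymmetric variant above
# (hypothetical values):
#
#   x = torch.tensor([[0.0, 1.0]])  # per-token min 0.0, max 1.0
#   scale, zero_point = _choose_qparams_per_token_asymmetric_impl(x, torch.int8)
#   # scale = (1.0 - 0.0) / 255 ~= 0.00392 and zero_point = -128 (qmin), both float32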

quantized_decomposed_lib.define(
    "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "CompositeExplicitAutograd",
)
def choose_qparams_per_token_asymmetric(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return _choose_qparams_per_token_asymmetric_impl(input, dtype)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "Meta",
)
def choose_qparams_per_token_asymmetric_meta(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    size = (1, input.size(-1))
    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
        size, dtype=torch.int64, device=input.device
    )


def _per_token_quant_qparam_dim_check(input, scales, zero_points):
    num_tokens = math.prod(list(input.size())[:-1])
    assert (
        num_tokens == scales.numel()
    ), f"num_tokens: {num_tokens} scales: {scales.size()}"
    assert (
        num_tokens == zero_points.numel()
    ), f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"


quantized_decomposed_lib.define(
    "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd")
def quantize_per_token(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
):
    """Per token quantization for the Tensor using the quantization parameters to map
    from floating point to quantized values. This means for an N-dimensional Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
    every N elements with the same quantization parameter. The dimension for scales/zero_points
    will be (M1 * M2 ... * Mn)

    Args:
       input (torch.Tensor): original float32 or bfloat16 Tensor
       scales (float32 torch.Tensor): quantization parameter for per token affine quantization
       zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    _per_token_quant_qparam_dim_check(input, scales, zero_points)
    input = (
        input.mul(1.0 / scales)
        .add(zero_points)
        .round()
        .clamp(quant_min, quant_max)
        .to(dtype)
    )
    return input

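# Illustrative usage sketch (hypothetical shapes), pairing this op with one of the
# per-token qparam choosers above:
#
#   x = torch.randn(2, 5, 16)
#   scales, zero_points = choose_qparams_per_token_asymmetric(x, torch.int8)
#   q = quantize_per_token(x, scales, zero_points, -128, 127, torch.int8)
#   # q[i, j, :] == clamp(round(x[i, j, :] / scales[i, j] + zero_points[i, j]), -128, 127)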

@impl(quantized_decomposed_lib, "quantize_per_token", "Meta")
def quantize_per_token_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
):
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)


quantized_decomposed_lib.define(
    "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
    "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd")
def dequantize_per_token(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    output_dtype: torch.dtype = torch.float32,
):
    """Per token dequantization for the Tensor using the quantization parameters to map
    from quantized values back to floating point values. This means for an N-dimensional Tensor
    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and dequantize
    every N elements with the same quantization parameter. The dimension for scales/zero_points
    will be (M1 * M2 ... * Mn)

    Args:
       input (torch.Tensor): quantized Tensor (uint8, int8 etc.)
       scales (float32 torch.Tensor): quantization parameter for per token affine quantization
       zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
       quant_min (int): minimum quantized value for input Tensor
       quant_max (int): maximum quantized value for input Tensor
       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor

    Returns:
       dequantized Tensor with dtype `output_dtype`
    """
    input = input - zero_points
    input = input.to(output_dtype) * scales
    return input

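# Illustrative usage sketch (hypothetical shapes, continuing the per-token example
# above): (q - zero_point) * scale per token, computed in `output_dtype`:
#
#   x_hat = dequantize_per_token(q, scales, zero_points, -128, 127, torch.int8, torch.float32)
#   # x_hat approximately recovers x as a float32 tensor of shape (2, 5, 16)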

@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta")
def dequantize_per_token_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    output_dtype: torch.dtype = torch.float32,
):
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    # TODO: support fp16
    return torch.empty_like(input, dtype=output_dtype)


quantized_decomposed_lib.define(
    "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
    "int quant_max, ScalarType dtype, int group_size) -> Tensor"
)


# TODO: dtype is ignored for now
@impl(
    quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd"
)
def quantize_per_channel_group(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size=128,
):
    assert group_size > 1
    # needed for GPTQ single column quantize
    if group_size > input.shape[-1] and scales.shape[-1] == 1:
        group_size = input.shape[-1]

    assert input.shape[-1] % group_size == 0
    assert input.dim() == 2

    # TODO: check for dtype, currently we can't express torch.int4 so it's omitted
    to_quant = input.reshape(-1, group_size)
    assert torch.isnan(to_quant).sum() == 0

    scales = scales.reshape(-1, 1)
    zero_points = zero_points.reshape(-1, 1)

    input_int8 = (
        to_quant.mul(1.0 / scales)
        .add(zero_points)
        .round()
        .clamp_(quant_min, quant_max)
        .to(dtype)
        .reshape_as(input)
    )

    return input_int8

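# Illustrative usage sketch (hypothetical qparams; in practice they come from a
# separate chooser): a (4, 256) weight quantized in groups of 128 columns per row:
#
#   w = torch.randn(4, 256)
#   scales = w.reshape(4, 2, 128).abs().amax(dim=-1) / 127  # one scale per (row, group)
#   zero_points = torch.zeros(4, 2)
#   q = quantize_per_channel_group(w, scales, zero_points, -127, 127, torch.int8, group_size=128)
#   # each 128-element group of a row shares one (scale, zero_point) pair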

@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta")
def quantize_per_channel_group_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size=128,
):
    """Groupwise quantization within each channel for a 2-d Tensor using the quantization parameters
    to map from floating point to quantized values. This means for each row of a 2-d Tensor
    (M, N), we calculate scales/zero_points for each `group_size` elements
    and quantize every `group_size` elements with the same quantization parameter.
    The dimension for scales/zero_points will be (M * ceil(N / group_size),)

    Args:
       input (torch.Tensor): original float32 or bfloat16 Tensor
       scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
       zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    assert group_size > 1
    # needed for GPTQ single column quantize
    if group_size > input.shape[-1] and scales.shape[-1] == 1:
        group_size = input.shape[-1]

    assert input.shape[-1] % group_size == 0
    assert input.dim() == 2
    return torch.empty_like(input, dtype=dtype)


quantized_decomposed_lib.define(
    "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, "
    "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_channel_group",
    "CompositeExplicitAutograd",
)
def dequantize_per_channel_group(
    w_int8: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size: int = 128,
    output_dtype: torch.dtype = torch.float32,
):
    """Groupwise dequantization within each channel for a 2-d Tensor using the quantization parameters
    to map from quantized values back to floating point values. This means for each row of a 2-d Tensor
    (M, N), we calculate scales/zero_points for each `group_size` elements
    and dequantize every `group_size` elements with the same quantization parameter.
    The dimension for scales/zero_points will be (M * ceil(N / group_size),)

    Args:
       input (torch.Tensor): quantized Tensor (uint8/int8 etc.)
       scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
       zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
       quant_min (int): minimum quantized value for input Tensor
       quant_max (int): maximum quantized value for input Tensor
       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor

    Returns:
       dequantized Tensor with dtype `output_dtype`
    """

    assert group_size > 1
    # needed for GPTQ single column dequantize
    if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
        group_size = w_int8.shape[-1]
    assert w_int8.shape[-1] % group_size == 0
    assert w_int8.dim() == 2

    w_int8_grouped = w_int8.reshape(-1, group_size)
    scales = scales.reshape(-1, 1)
    if zero_points is not None:
        zp = zero_points.reshape(-1, 1)
    else:
        zp = torch.zeros([], dtype=torch.int32, device=scales.device)
    w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype)
    return w_dq

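# Illustrative usage sketch (hypothetical values, continuing the group-wise example
# above): each 128-column group of a row is mapped back with its own scale/zero_point:
#
#   w_hat = dequantize_per_channel_group(
#       q, scales, zero_points, -127, 127, torch.int8, group_size=128, output_dtype=torch.float32
#   )
#   # w_hat approximately recovers w as a float32 (4, 256) tensor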

quantized_decomposed_lib.define(
    "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max) -> Tensor"
)


class FakeQuantPerChannel(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
        if scales.dtype != torch.float32:
            scales = scales.to(torch.float32)
        if zero_points.dtype != torch.int32:
            zero_points = zero_points.to(torch.int32)
        assert (
            input.dtype == torch.float32
        ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
        assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
        broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim))
        unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
        unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
        temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points
        out = (
            torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points
        ) * unsqueeze_scales
        mask = torch.logical_and((temp >= quant_min), (temp <= quant_max))

        ctx.save_for_backward(mask)
        return out

    @staticmethod
    def backward(ctx, gy):
        (mask,) = ctx.saved_tensors
        return gy * mask, None, None, None, None, None


@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd")
def fake_quant_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
    return FakeQuantPerChannel.apply(
        input, scales, zero_points, axis, quant_min, quant_max
    )

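# Illustrative usage sketch (hypothetical values): fake quantization quantizes and
# immediately dequantizes, and the custom backward passes gradients straight through
# only where the (pre-clamp) quantized value stayed inside [quant_min, quant_max]:
#
#   w = torch.randn(8, 4, requires_grad=True)
#   scales = torch.full((8,), 0.1)
#   zero_points = torch.zeros(8, dtype=torch.int32)
#   out = fake_quant_per_channel(w, scales, zero_points, 0, -128, 127)
#   out.sum().backward()  # w.grad is 1 where not clamped, 0 where clamped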

@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta")
def fake_quant_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
    return torch.empty_like(input)
