#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>

namespace at {
namespace native {

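// Out-variant helpers: quantize the values of `rtensor` into the
// preallocated quantized tensor `qtensor` (and the reverse for the
// dequantize_* overloads), returning a reference to the output tensor.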
Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point);
Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis);

Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point);
Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    Tensor zero_points,
    int64_t axis);
Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

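// Function-pointer signatures for the backend kernels behind the
// DispatchStub declarations below.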
using quantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);

using quantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using quantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);

using dequantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

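// Sub-byte variants (e.g. 4-bit quantized dtypes); note that scale and
// zero_point are passed as float for these kernels.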
using quantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);

using dequantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);

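// Per-backend dispatch stubs. Concrete kernels are registered elsewhere
// (typically via REGISTER_DISPATCH in the CPU/CUDA kernel files) and the
// stub is invoked with an explicit device type, e.g. (sketch):
//   quantize_tensor_per_tensor_affine_stub(
//       rtensor.device().type(), rtensor, qtensor, scale, zero_point);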
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_fn,
    quantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_affine_fn,
    quantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_float_qparams_fn,
    quantize_tensor_per_channel_float_qparams_stub);

DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_fn,
    dequantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_affine_fn,
    dequantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_float_qparams_fn,
    dequantize_tensor_per_channel_float_qparams_stub);

DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_sub_byte_fn,
    quantize_tensor_per_tensor_affine_sub_byte_stub);

DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_sub_byte_fn,
    dequantize_tensor_per_tensor_affine_sub_byte_stub);

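// Exported (TORCH_API) templated quantize/dequantize entry points.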
template <typename T>
TORCH_API Tensor quantize_tensor(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);
template <typename T>
TORCH_API Tensor dequantize_tensor(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);

} // namespace native
} // namespace at