#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/core/IListRef.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/Activation.h>
#include <ATen/native/DispatchStub.h>

#include <optional>

namespace at {
namespace native {

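// Function pointer signatures for the quantized CPU kernels declared as
// DispatchStubs at the bottom of this file. Per-architecture kernel
// implementations register themselves against these stubs (via
// REGISTER_DISPATCH) in the corresponding kernel .cpp files.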
using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
                                const Scalar& /*negval_*/);
using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */);
using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qclamp_fn = void (*)(
    const at::Tensor& /*qx*/,
    const Scalar& min,
    const Scalar& max,
    at::Tensor& /*qy*/);
using qclamp_minmax_fn = void (*)(
    const at::Tensor& /*qx*/,
    const Scalar& /*min or max*/,
    at::Tensor& /*qy*/);
using qthreshold_fn = void (*)(
    const at::Tensor& /*qx*/,
    const Scalar& threshold,
    const Scalar& value,
    at::Tensor& /*qy*/);
using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qelu_fn = void(*)(
    const at::Tensor& /*qx*/,
    const Scalar& /*alpha*/,
    const Scalar& /*scale*/,
    const Scalar& /*input_scale*/,
    at::Tensor& /*qy*/);
using qbinary_fn =
    void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
using qadd_scalar_fn =
    void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Scalar& /*other*/);
using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qdropout_fn = void(*)(
    const at::Tensor& /*qx*/,
    const Scalar& /*p*/,
    bool /*training*/,
    at::Tensor& /*qy*/);
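
// Pooling kernel signatures. Dimension suffixes denote batch (B), channel
// (C), time/depth (T/D), height (H) and width (W); see the inline comments
// on each signature for the meaning of the size/kernel/stride/padding
// parameters. The *_nhwc/_ndhwc stub names below indicate that these kernels
// operate on channels-last (NHWC / NDHWC) tensors.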
using qmaxpool_2d_fn = void (*)(
    const Tensor& qx,
    int64_t iC, // input/output channels
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sH,
    int64_t sW, // strides
    int64_t pH,
    int64_t pW, // padding
    int64_t dH,
    int64_t dW, // dilation
    Tensor& qy);
using qmaxpool_3d_fn = void (*)(
    const Tensor& qx,
    int64_t iC, // input/output channels
    int64_t iT,
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oT,
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    Tensor& qy);
using qadaptive_avg_pool2d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t sizeB,
    int64_t sizeC,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeH,
    int64_t osizeW,
    int64_t istrideB,
    int64_t istrideC,
    int64_t istrideH,
    int64_t istrideW);
using qadaptive_avg_pool3d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t sizeB,
    int64_t sizeC,
    int64_t isizeD,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeD,
    int64_t osizeH,
    int64_t osizeW,
    int64_t istrideB,
    int64_t istrideC,
    int64_t istrideD,
    int64_t istrideH,
    int64_t istrideW);
using qavg_pool2d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t nBatch,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t outputWidth,
    int64_t outputHeight,
    int kW,
    int kH,
    int dW,
    int dH,
    int padW,
    int padH,
    bool count_include_pad,
    std::optional<int64_t> divisor_override);

using qavg_pool3d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t nBatch,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t inputDepth,
    int64_t outputWidth,
    int64_t outputHeight,
    int64_t outputDepth,
    int kW,
    int kH,
    int kD,
    int dW,
    int dH,
    int dD,
    int padW,
    int padH,
    int padD,
    bool count_include_pad,
    std::optional<int64_t> divisor_override);

using qupsample_bilinear2d_fn = void (*)(
    Tensor& output,
    const Tensor& input,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w);

using qcat_nhwc_fn = Tensor (*)(
    const MaterializedITensorListRef& qxs,
    int64_t dim,
    double scale,
    int64_t zero_point);
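// qcat concatenates quantized channels-last tensors along `dim`,
// requantizing the result to the given output scale and zero_point.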
using qtopk_fn = void(*)(
    Tensor& /*values*/,
    Tensor& /*indices*/,
    const Tensor& /*self*/,
    int64_t /*k*/,
    int64_t /*dim*/,
    bool /*largest*/,
    bool /*sorted*/);

using qbatch_norm_fn = void(*)(
    int64_t /*N*/,
    int64_t /*C*/,
    int64_t /*HxW*/,
    int64_t /*in_zero_point*/,
    int64_t /*out_zero_point*/,
    const Tensor& /*input*/,
    const Tensor& /*alpha*/,
    const Tensor& /*beta*/,
    Tensor& /*output*/);

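// Shared normalization kernel signature: the quantized LayerNorm and
// GroupNorm implementations dispatch through quantized_normalize_stub
// (contiguous) and quantized_groupnorm_nhwc_stub (channels-last).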
using qnormalize_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    bool /* affine_per_channel */,
    int /* num_channels */,
    int /* num_groups */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */);

using qmean_inner_dim_fn = void (*)(
    const Tensor& /* X */,
    OptionalIntArrayRef /* opt_dim */,
    bool /* keepdim */,
    std::optional<ScalarType> /* opt_dtype */,
    Tensor& /* Y */);

using qstd_inner_dim_fn = void (*)(
    const Tensor& /* X */,
    OptionalIntArrayRef /* dim */,
    const std::optional<Scalar>& /* correction */,
    bool /* keepdim */,
    Tensor& /* Y */);

using qnormalize_nhwc_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    bool /* affine_per_channel */,
    int /* num_channels */,
    int /* num_groups */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */);

using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
                           const Tensor& /*qw*/);

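// Stub declarations for the signatures above. A kernel implementation
// registers itself per architecture with REGISTER_DISPATCH in its .cpp
// file, e.g. (illustrative; qrelu_kernel stands for the registered kernel):
//   REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
// Callers then dispatch through the stub by passing the device type as the
// first argument:
//   qrelu_stub(qx.device().type(), qx, qy);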
DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);
DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub);
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub);
DECLARE_DISPATCH(qavg_pool2d_fn, qavg_pool2d_nhwc_stub);
DECLARE_DISPATCH(qavg_pool3d_fn, qavg_pool3d_nhwc_stub);
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub);
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub);
DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub);
DECLARE_DISPATCH(qbinary_fn, qadd_stub);
DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub);
DECLARE_DISPATCH(qbinary_fn, qmul_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
DECLARE_DISPATCH(qclamp_fn, qclamp_stub);
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub);
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub);
DECLARE_DISPATCH(qelu_fn, qelu_stub);
DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub);
DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub);
DECLARE_DISPATCH(qdropout_fn, qdropout_stub);
DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
DECLARE_DISPATCH(qmaxpool_3d_fn, qmaxpool_3d_nthwc_stub);
DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub);
DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub);
DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub);
DECLARE_DISPATCH(qgelu_fn, qgelu_stub);
DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub);
DECLARE_DISPATCH(qtanh_fn, qtanh_stub);
DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub);
DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub);
DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub);
DECLARE_DISPATCH(qprelu_fn, qprelu_stub);

} // namespace native
} // namespace at