#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <torch/library.h>
#include <ATen/native/Pool.h>
#include <ATen/native/MaxPooling.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/quantized_max_pool1d.h>
#include <ATen/ops/quantized_max_pool1d_native.h>
#include <ATen/ops/quantized_max_pool2d.h>
#include <ATen/ops/quantized_max_pool2d_native.h>
#include <ATen/ops/quantized_max_pool3d_native.h>
#endif

#include <algorithm>
#include <limits>
#include <vector>

namespace at {
namespace native {

DEFINE_DISPATCH(qmaxpool_2d_nhwc_stub);
DEFINE_DISPATCH(qmaxpool_3d_nthwc_stub);
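// These stubs are filled in by the vectorized channels-last kernels
// registered elsewhere; the dispatch mechanism picks the best implementation
// available for the current CPU at runtime.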

namespace {

/* Computes 2D max pooling with dilation over a single (C, H, W) image.

   Each argument is described inline in the parameter list below.
*/
template <typename T>
void spatial_dilated_max_pooling(
    const T* iData,
    int64_t iC, // input/output channels
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sH,
    int64_t sW, // strides
    int64_t pH,
    int64_t pW, // padding
    int64_t dH,
    int64_t dW, // dilation
    T* oData) { // output data
  at::parallel_for(0, iC, 0, [&](int64_t start, int64_t end) {
    for (const auto p : c10::irange(start, end)) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t row, col;
      const T* i_p = iData + p * iW * iH;
      for (row = 0; row < oH; ++row) {
        for (col = 0; col < oW; ++col) {
          int64_t h_start = row * sH - pH;
          int64_t w_start = col * sW - pW;
          int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
          int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);
          while (h_start < 0)
            h_start += dH;
          while (w_start < 0)
            w_start += dW;
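          // A negative start index lies inside the (implicit) padding;
          // stepping forward by the dilation keeps the window on the
          // dilation grid while skipping the padded region.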

          // local pointers
          T* o_p = oData + p * oW * oH + row * oW + col;

          // local max
          auto max_val = std::numeric_limits<typename T::underlying>::lowest();
          int64_t tcntr = 0; // flattened offset into the input image
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          int64_t x, y;
          for (y = h_start; y < h_end; y += dH) {
            for (x = w_start; x < w_end; x += dW) {
              tcntr = y * iW + x;
              auto val = (i_p + tcntr)->val_;
              if (val > max_val) {
                max_val = val;
              }
            }
          }
          *o_p = T(max_val); // Output.
        }
      }
    }
  });
}

template <typename T>
void spatial_dilated_max_pooling3d(
    const T* qxd,
    int64_t nbatch,
    int64_t iC, // input/output channels
    int64_t iT,
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oT,
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    T* qyd) { // output data
  // TODO: Further optimize performance as suggested by @mingfeima:
  // parallelize over NCTH and cache the output indices from W.
  int64_t oC = iC;
  int64_t parallel_dim = nbatch * iC;
  at::parallel_for(0, parallel_dim, 0, [&](int64_t start, int64_t end) {
    for (const auto p : c10::irange(start, end)) {
      // Decompose the flattened (batch, channel) index.
      int64_t batch_idx = p / iC;
      int64_t channel_idx = p - batch_idx * iC;

      auto* iData = qxd + batch_idx * iC * iT * iH * iW;
      auto* oData = qyd + batch_idx * oC * oT * oH * oW;

      // Handle each channel.
      int64_t time, row, col;
      const T* i_p = iData + channel_idx * iT * iW * iH;
      for (time = 0; time < oT; ++time) {
        for (row = 0; row < oH; ++row) {
          for (col = 0; col < oW; ++col) {
            // Handle each output element.
            int64_t t_start = time * sT - pT;
            int64_t h_start = row * sH - pH;
            int64_t w_start = col * sW - pW;
            int64_t t_end = std::min(t_start + (kT - 1) * dT + 1, iT);
            int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
            int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);

            // As in the 2D case, step the start indices out of the padded
            // region along the dilation grid.
            while (t_start < 0)
              t_start += dT;
            while (h_start < 0)
              h_start += dH;
            while (w_start < 0)
              w_start += dW;

            // local pointers
            T* o_p = oData + channel_idx * oT * oH * oW + time * oH * oW + row * oW + col;

            // local max
            auto max_val = std::numeric_limits<typename T::underlying>::lowest();
            int64_t tcntr = 0; // flattened offset into the input image
            // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
            int64_t t, x, y;
            for (t = t_start; t < t_end; t += dT) {
              for (y = h_start; y < h_end; y += dH) {
                for (x = w_start; x < w_end; x += dW) {
                  tcntr = t * iH * iW + y * iW + x;
                  auto val = (i_p + tcntr)->val_;
                  if (val > max_val) {
                    max_val = val;
                  }
                }
              }
            }
            *o_p = T(max_val); // Output.
          }
        }
      }
    }
  });
}

template <typename Q>
Tensor q_maxpool_2d(
    Tensor qx, // Input Tensor (Quantized)
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sH,
    int64_t sW, // strides
    int64_t pH,
    int64_t pW, // padding
    int64_t dH,
    int64_t dW, // dilation
    bool ceil_mode) {
  // Check input dimensions.
  TORCH_CHECK(kH > 0 && kW > 0, "kernel_size should be greater than zero.");
  TORCH_CHECK(sH > 0 && sW > 0, "strides should be greater than zero.");
  TORCH_CHECK(
      dH > 0 && dW > 0,
      "dilation should be greater than zero. "
      "Got (",
      dH,
      ", ",
      dW,
      ")");

  int ndim = qx.dim();
  TORCH_CHECK(
      ndim == 3 || ndim == 4, "Expecting the input tensor to be of rank 3 or 4.");
  int dimc = 0;
  int dimh = 1;
  int dimw = 2;
  int nbatch = 1;
  if (ndim == 4) { // Includes batches
    ++dimc;
    ++dimh;
    ++dimw;
    nbatch = qx.size(0);
  }

  // Check if inputs are valid.
  int64_t iC = qx.size(dimc);
  int64_t iH = qx.size(dimh);
  int64_t iW = qx.size(dimw);
  TORCH_CHECK(iC > 0 && iH > 0 && iW > 0, "input dimensions must be non-zero.");
  TORCH_CHECK(
      kH / 2 >= pH && kW / 2 >= pW,
      "padding should be smaller than half of kernel_size.");

  // Compute and check output dimensions.
  int64_t oC = iC;
  int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
  int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
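  // pooling_output_shape applies the standard formula
  //   oH = (iH + 2 * pH - dH * (kH - 1) - 1) / sH + 1
  // with flooring (or ceiling when ceil_mode is set, never starting a
  // window beyond the padded input).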
  TORCH_CHECK(oH > 0 && oW > 0,
      "Given input size: (",
      iC, "x", iH, "x", iW,
      "). Calculated output size: (",
      oC, "x", oH, "x", oW,
      "). Output size is too small.");

  std::vector<int64_t> oSizes;
  if (ndim == 3) {
    oSizes = {oC, oH, oW};
  } else {
    oSizes = {nbatch, oC, oH, oW};
  }

  if (qx.is_contiguous(c10::MemoryFormat::ChannelsLast)) {
    // Fast path for channels-last input: the data layout is preserved in
    // memory and the loop nest is more amenable to vectorization.
    Tensor qy;
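    // Plain uint8 inputs carry no quantization parameters, so a regular
    // empty() suffices; true quantized dtypes must propagate scale and
    // zero_point through _empty_affine_quantized().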
    if constexpr (std::is_same_v<Q, uint8_t>) {
      qy = at::empty(
          oSizes,
          qx.options()
              .device(c10::kCPU)
              .dtype(qx.scalar_type())
              .memory_format(c10::MemoryFormat::ChannelsLast));
    } else {
      qy = at::_empty_affine_quantized(
          oSizes,
          qx.options()
              .dtype(toQIntType(qx.scalar_type()))
              .memory_format(qx.suggest_memory_format()),
          qx.q_scale(),
          qx.q_zero_point(),
          std::nullopt);
    }
    qmaxpool_2d_nhwc_stub(qx.device().type(), qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
    return qy;
  } else {
    Tensor qy;
    if constexpr (!std::is_same_v<Q, uint8_t>) {
      qy = at::_empty_affine_quantized(
          oSizes,
          qx.options().dtype(toQIntType(qx.scalar_type())),
          qx.q_scale(),
          qx.q_zero_point());
      auto qx_contig = qx.contiguous();
      auto qxd = qx_contig.data_ptr<Q>();
      auto qyd = qy.data_ptr<Q>();
      if (ndim == 3 || nbatch == 1) {
        auto* iData = qxd;
        auto* oData = qyd;
        spatial_dilated_max_pooling<Q>(
            iData,
            iC,
            iH,
            iW,
            oH,
            oW,
            kH,
            kW,
            sH,
            sW,
            pH,
            pW,
            dH,
            dW,
            oData);
      } else {
        // Parallelize over the batch dimension; each image is pooled
        // independently.
        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
          for (const auto p : c10::irange(start, end)) {
            auto* iData = qxd + p * iC * iW * iH;
            auto* oData = qyd + p * oC * oW * oH;
            spatial_dilated_max_pooling<Q>(
                iData,
                iC,
                iH,
                iW,
                oH,
                oW,
                kH,
                kW,
                sH,
                sW,
                pH,
                pW,
                dH,
                dW,
                oData);
          }
        });
      }
    } else {
      // qx is uint8 in contiguous memory format: run the channels-last
      // kernel and convert qy back to contiguous afterwards.
      qy = at::empty(
          oSizes,
          qx.options()
              .device(c10::kCPU)
              .dtype(qx.scalar_type())
              .memory_format(c10::MemoryFormat::ChannelsLast));
      auto qx_nhwc = qx.contiguous(c10::MemoryFormat::ChannelsLast);
      qmaxpool_2d_nhwc_stub(qx_nhwc.device().type(), qx_nhwc, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
      qy = qy.contiguous();
    }
    return qy;
  }
}

template <typename Q>
Tensor q_maxpool_3d(
    Tensor qx, // Input Tensor (Quantized)
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    bool ceil_mode) {
  // Check input dimensions.
  TORCH_CHECK(kT > 0 && kH > 0 && kW > 0, "kernel_size should be greater than zero.");
  TORCH_CHECK(sT > 0 && sH > 0 && sW > 0, "strides should be greater than zero.");
  TORCH_CHECK(
      dT > 0 && dH > 0 && dW > 0,
      "dilation should be greater than zero. "
      "Got (",
      dT,
      ", ",
      dH,
      ", ",
      dW,
      ")");
  int ndim = qx.dim();
  // TODO leslie: Support non-batch mode, where the input is a 4-d THWC tensor.
  TORCH_CHECK(ndim == 5, "Expecting the input tensor to be of rank 5.");

  // act: n, c, t, h, w
  int dimc = 1;
  int dimt = 2;
  int dimh = 3;
  int dimw = 4;
  int nbatch = qx.size(0);
  // Check if inputs are valid.
  int64_t iC = qx.size(dimc);
  int64_t iT = qx.size(dimt);
  int64_t iH = qx.size(dimh);
  int64_t iW = qx.size(dimw);
  TORCH_CHECK(iC > 0 && iT > 0 && iH > 0 && iW > 0, "input dimensions must be non-zero.");
  TORCH_CHECK(
      kT / 2 >= pT && kH / 2 >= pH && kW / 2 >= pW,
      "padding should be smaller than half of kernel_size.");

  // Compute and check output dimensions.
  int64_t oC = iC;
  int64_t oT = pooling_output_shape(iT, kT, pT, sT, dT, ceil_mode);
  int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
  int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
  TORCH_CHECK(oT > 0 && oH > 0 && oW > 0,
      "Given input size: (",
      iC, "x", iT, "x", iH, "x", iW,
      "). Calculated output size: (",
      oC, "x", oT, "x", oH, "x", oW,
      "). Output size is too small.");

  std::vector<int64_t> oSizes = {nbatch, oC, oT, oH, oW};

  if (qx.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
    // Fast path for channels-last-3d input: the data layout is preserved in
    // memory and the loop nest is more amenable to vectorization.
    Tensor qy = at::_empty_affine_quantized(
        oSizes,
        qx.options()
            .dtype(toQIntType(qx.scalar_type()))
            .memory_format(qx.suggest_memory_format()),
        qx.q_scale(),
        qx.q_zero_point(),
        std::nullopt);
    qmaxpool_3d_nthwc_stub(qx.device().type(), qx, iC, iT, iH, iW, oT, oH, oW, kT, kH, kW, sT, sH, sW, pT, pH, pW, dT, dH, dW, qy);
    return qy;
  } else {
    Tensor qy = at::_empty_affine_quantized(
        oSizes,
        qx.options().dtype(toQIntType(qx.scalar_type())),
        qx.q_scale(),
        qx.q_zero_point());
    auto qx_contig = qx.contiguous();
    auto qxd = qx_contig.data_ptr<Q>();
    auto qyd = qy.data_ptr<Q>();

    spatial_dilated_max_pooling3d<Q>(
        qxd,
        nbatch,
        iC,
        iT,
        iH,
        iW,
        oT,
        oH,
        oW,
        kT,
        kH,
        kW,
        sT,
        sH,
        sW,
        pT,
        pH,
        pW,
        dT,
        dH,
        dW,
        qyd);

    return qy;
  }
}
} // namespace

namespace {
void check_maxpool2d_params(
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
      "Expected 1d or 2d kernel size, got ", kernel_size.size());
  TORCH_CHECK(stride.empty() || stride.size() == 2,
      "Expected no strides or 2d strides, got ", stride.size());
  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
      "Expected 1d or 2d padding, got ", padding.size());
  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
      "Expected 1d or 2d dilation, got ", dilation.size());
}

void check_maxpool3d_params(
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  TORCH_CHECK(kernel_size.size() == 3, "Expected 3d kernel size, got ", kernel_size.size());
  TORCH_CHECK(stride.empty() || stride.size() == 3,
      "Expected no strides or 3d strides, got ", stride.size());
  TORCH_CHECK(padding.size() == 3, "Expected 3d padding, got ", padding.size());
  TORCH_CHECK(dilation.size() == 3, "Expected 3d dilation, got ", dilation.size());
}

#ifdef USE_PYTORCH_QNNPACK
static Tensor qnnpack_maxpool2d(
    Tensor input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  Tensor qy;

  TORCH_CHECK(
      input.ndimension() == 4,
      "qnnpack_maxpool2d(): Expected input to be 4-dimensional: got ",
      input.ndimension());
  TORCH_CHECK(
      kernel_size.size() == 2,
      "qnnpack_maxpool2d(): Expected kernel_size to be 2-dimensional: got ",
      kernel_size.size());
  TORCH_CHECK(
      stride.size() == 2,
      "qnnpack_maxpool2d(): Expected stride to be 2-dimensional: got ",
      stride.size());
  TORCH_CHECK(
      dilation.size() == 2,
      "qnnpack_maxpool2d(): Expected dilation to be 2-dimensional: got ",
      dilation.size());
  TORCH_CHECK(
      padding.size() == 2,
      "qnnpack_maxpool2d(): Expected padding to be 2-dimensional: got ",
      padding.size());

  int64_t batch_size = input.size(0);
  int64_t inC = input.size(1);
  int64_t inH = input.size(2);
  int64_t inW = input.size(3);
  Tensor input_contig = input.contiguous(MemoryFormat::ChannelsLast);
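  // QNNPACK's max-pooling kernel consumes NHWC data, hence the
  // channels-last copy above.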

  initQNNPACK();
  const auto scale = input_contig.q_scale();
  const auto zero_point = input_contig.q_zero_point();
  pytorch_qnnp_operator_t qnnpack_operator{nullptr};

  int64_t padH = padding[0];
  int64_t padW = padding[1];
  int64_t kH = kernel_size[0];
  int64_t kW = kernel_size[1];
  int64_t strideH = stride[0];
  int64_t strideW = stride[1];
  int64_t dilationH = dilation[0];
  int64_t dilationW = dilation[1];

  TORCH_CHECK(
      kH > 0 && kW > 0,
      "qnnpack_maxpool2d(): kernel_size should be greater than zero.");
  TORCH_CHECK(
      strideH > 0 && strideW > 0,
      "qnnpack_maxpool2d(): strides should be greater than zero.");

  const pytorch_qnnp_status createStatus =
      pytorch_qnnp_create_max_pooling2d_nhwc_u8(
          padH /* input_padding_height */,
          padW /* input_padding_width */,
          kH /* pooling height */,
          kW /* pooling width */,
          strideH /* stride height */,
          strideW /* stride width */,
          dilationH /* dilation height */,
          dilationW /* dilation width */,
          inC /* input channels */,
          std::numeric_limits<uint8_t>::min() /* output min */,
          std::numeric_limits<uint8_t>::max() /* output max */,
          0 /* flags */,
          &qnnpack_operator);
  TORCH_INTERNAL_ASSERT(
      createStatus == pytorch_qnnp_status_success,
      "failed to create QNNPACK MaxPool operator");

  int64_t outC = inC;
  int64_t outH =
      pooling_output_shape(inH, kH, padH, strideH, dilationH, ceil_mode);
  int64_t outW =
      pooling_output_shape(inW, kW, padW, strideW, dilationW, ceil_mode);

  TORCH_CHECK(
      outH > 0 && outW > 0,
      "qnnpack_maxpool2d(): the resulting output Tensor size should be positive");

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(qnnpack_operator);
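  // RAII guard: the deleter frees the QNNPACK operator even if a later
  // check throws.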

  // NHWC output
  qy = at::_empty_affine_quantized(
      {batch_size, outC, outH, outW},
      at::device(kCPU).dtype(kQUInt8),
      scale,
      zero_point,
      MemoryFormat::ChannelsLast);

  const pytorch_qnnp_status setupStatus =
      pytorch_qnnp_setup_max_pooling2d_nhwc_u8(
          qnnpack_operator /* max pooling */,
          batch_size /* batch size */,
          inH /* input height */,
          inW /* input width */,
          (uint8_t*)input_contig.data_ptr<c10::quint8>() /* input */,
          inC /* input_pixel_stride */,
          (uint8_t*)qy.data_ptr<c10::quint8>() /* output data */,
          outC /* output_pixel_stride */,
          nullptr /* thread pool */);
  TORCH_INTERNAL_ASSERT(
      setupStatus == pytorch_qnnp_status_success,
      "failed to setup QNNPACK MaxPool operator");

  pthreadpool_t threadpool = caffe2::pthreadpool_();
  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK MaxPool operator");
  return qy.contiguous(input.suggest_memory_format());
}
#endif
} // namespace

// at::native functions for the native_functions.yaml
Tensor quantized_max_pool2d(
    const Tensor& qx,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  check_maxpool2d_params(
      kernel_size,
      stride,
      padding,
      dilation);
  if (stride.empty()) {
    stride = kernel_size;
  }
#ifdef USE_PYTORCH_QNNPACK
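  // The QNNPACK fast path is only taken for quint8 activations without
  // ceil_mode; everything else falls through to the native kernels below.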
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8 && !ceil_mode) {
    return qnnpack_maxpool2d(qx, kernel_size, stride, padding, dilation, ceil_mode);
  }
#endif
  Tensor qy;
  AT_DISPATCH_QINT_TYPES_AND(ScalarType::Byte, qx.scalar_type(), "max_pool2d", [&]() {
    qy = q_maxpool_2d<scalar_t>(
        qx,
        kernel_size[0],
        kernel_size[1],
        stride[0],
        stride[1],
        padding[0],
        padding[1],
        dilation[0],
        dilation[1],
        ceil_mode);
  });
  return qy;
}

Tensor quantized_max_pool3d(
    const Tensor& qx,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  check_maxpool3d_params(
      kernel_size,
      stride,
      padding,
      dilation);
  if (stride.empty()) {
    stride = kernel_size;
  }
#ifdef USE_PYTORCH_QNNPACK
  TORCH_CHECK(at::globalContext().qEngine() != at::QEngine::QNNPACK,
      "QNNPACK backend doesn't support quantized_max_pool3d");
#endif
  Tensor qy;
  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool3d", [&]() {
    qy = q_maxpool_3d<scalar_t>(
        qx,
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        stride[0],
        stride[1],
        stride[2],
        padding[0],
        padding[1],
        padding[2],
        dilation[0],
        dilation[1],
        dilation[2],
        ceil_mode);
  });
  return qy;
}

// Quantized max_pool1d is a special case of max_pool2d, with one of the
// spatial dimensions and kernels removed.
Tensor quantized_max_pool1d(
    const Tensor& qx,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  check_max_pool1d(qx, kernel_size, stride, padding, dilation, ceil_mode);
  // (C, L) -> (C, 1, L) => kSqueezeDim = 1
  // (N, C, L) -> (N, C, 1, L) => kSqueezeDim = 2
  const int32_t kSqueezeDim = qx.dim() - 1;
  const auto qx_unsqueeze = qx.unsqueeze(kSqueezeDim);
  if (stride.empty()) {
    stride = kernel_size;
  }
  auto qy = at::quantized_max_pool2d(
      qx_unsqueeze,
      {1, kernel_size[0]},
      {1, stride[0]},
      {0, padding[0]},
      {1, dilation[0]},
      ceil_mode);
  qy = qy.squeeze(kSqueezeDim);
  return qy;
}

// Keep the registry in the anonymous namespace.
namespace {
template <uint32_t kSpatialDim>
class QMaxPool_arr_args final {
 public:
  static Tensor run(
      const Tensor& qx,
      std::vector<int64_t> kernel_size,
      std::vector<int64_t> stride,
      std::vector<int64_t> padding,
      std::vector<int64_t> dilation,
      bool ceil_mode) {
    // Non-quantized uint8 2D inputs call the native implementation directly.
    if (!qx.is_quantized() && kSpatialDim == 2 && qx.scalar_type() == c10::ScalarType::Byte) {
      return at::native::quantized_max_pool2d(qx, kernel_size, stride, padding,
                                              dilation, ceil_mode);
    }
    if (kSpatialDim == 1) {
      return at::quantized_max_pool1d(qx, kernel_size, stride, padding,
                                      dilation, ceil_mode);
    } else if (kSpatialDim == 2) {
      return at::quantized_max_pool2d(qx, kernel_size, stride, padding,
                                      dilation, ceil_mode);
    }
    TORCH_CHECK(false, "MaxPool", kSpatialDim, "D is not supported.");
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool1d"), TORCH_FN(QMaxPool_arr_args<1>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool2d"), TORCH_FN(QMaxPool_arr_args<2>::run));
}

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool2d"), TORCH_FN(QMaxPool_arr_args<2>::run));
}

} // namespace
} // namespace native
} // namespace at