// File: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/Pooling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <torch/library.h>
#include <ATen/native/Pool.h>
#include <ATen/native/MaxPooling.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <c10/util/irange.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/quantized_max_pool1d.h>
#include <ATen/ops/quantized_max_pool1d_native.h>
#include <ATen/ops/quantized_max_pool2d.h>
#include <ATen/ops/quantized_max_pool2d_native.h>
#include <ATen/ops/quantized_max_pool3d_native.h>
#endif

#include <algorithm>
#include <vector>

namespace at {
namespace native {

DEFINE_DISPATCH(qmaxpool_2d_nhwc_stub);
DEFINE_DISPATCH(qmaxpool_3d_nthwc_stub);

namespace {

/* Computes spatial 2D max pooling with dilation.

Arguments are described inline in the parameter list below.
*/
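// Example: with iH = 5, kH = 3, dH = 2, sH = 1, pH = 0, the dilated kernel
// spans (kH - 1) * dH + 1 = 5 rows, so oH = 1 and the single output row takes
// the max over input rows {0, 2, 4}.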
template <typename T>
void spatial_dilated_max_pooling(
    const T* iData,
    int64_t iC, // input/output channels
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sH,
    int64_t sW, // strides
    int64_t pH,
    int64_t pW, // padding
    int64_t dH,
    int64_t dW, // dilation
    T* oData) { // output array
  at::parallel_for(0, iC, 0, [&](int64_t start, int64_t end) {
    for (const auto p : c10::irange(start, end)) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t row, col;
      const T* i_p = iData + p * iW * iH;
      for (row = 0; row < oH; ++row) {
        for (col = 0; col < oW; ++col) {
          int64_t h_start = row * sH - pH;
          int64_t w_start = col * sW - pW;
          int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
          int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);
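          // h_end/w_end were clamped to the input extent above; the loops
          // below advance h_start/w_start in whole dilation steps until the
          // first kernel tap is in bounds, so out-of-range taps are skipped
          // rather than padded with a sentinel value.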
          while (h_start < 0)
            h_start += dH;
          while (w_start < 0)
            w_start += dW;

          // local pointers
          T* o_p = oData + p * oW * oH + row * oW + col;

          // local max
          auto max_val = std::numeric_limits<typename T::underlying>::lowest();
          int64_t tcntr = 0; // center point
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          int64_t x, y;
          for (y = h_start; y < h_end; y += dH) {
            for (x = w_start; x < w_end; x += dW) {
              tcntr = y * iW + x;
              auto val = (i_p + tcntr)->val_;
              if (val > max_val) {
                max_val = val;
              }
            }
          }
          *o_p = T(max_val); // Output.
        }
      }
    }
  });
}

template <typename T>
void spatial_dilated_max_pooling3d(
    const T* qxd,
    int64_t nbatch,
    int64_t iC, // input/output channels
    int64_t iT,
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oT,
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    T* qyd) { // output array
  // TODO: Further optimize performance, as suggested by @mingfeima:
  // parallelize over NCTH and cache the output indices from W.
  // Handle each batch element.
  int64_t oC = iC;
  int64_t parallel_dim = nbatch * iC;
  at::parallel_for(0, parallel_dim, 0, [&](int64_t start, int64_t end) {
    for (const auto p : c10::irange(start, end)) {

      int64_t batch_idx = p / iC;
      int64_t channel_idx = p - batch_idx * iC;

      auto* iData = qxd + batch_idx * iC * iT * iH * iW;
      auto* oData = qyd + batch_idx * oC * oT * oH * oW;

      // Handle each channel
      int64_t time, row, col;
      const T* i_p = iData + channel_idx * iT * iW * iH;
      for (time = 0; time < oT; ++time) {
        for (row = 0; row < oH; ++row) {
          for (col = 0; col < oW; ++col) {
            // Handle each output element
            int64_t t_start = time * sT - pT;
            int64_t h_start = row * sH - pH;
            int64_t w_start = col * sW - pW;
            int64_t t_end = std::min(t_start + (kT - 1) * dT + 1, iT);
            int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
            int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);

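            // Same boundary handling as in the 2D variant above: clamp the
            // window ends, then step the starts forward by the dilation until
            // the first tap is in bounds.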
            while (t_start < 0)
              t_start += dT;
            while (h_start < 0)
              h_start += dH;
            while (w_start < 0)
              w_start += dW;

            // local pointers
            T* o_p = oData + channel_idx * oT * oH * oW + time * oH * oW + row * oW + col;

            // local max
            auto max_val = std::numeric_limits<typename T::underlying>::lowest();
            int64_t tcntr = 0; // center point
            // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
            int64_t t, x, y;
            for (t = t_start; t < t_end; t += dT) {
              for (y = h_start; y < h_end; y += dH) {
                for (x = w_start; x < w_end; x += dW) {
                  tcntr = t * iH * iW + y * iW + x;
                  auto val = (i_p + tcntr)->val_;
                  if (val > max_val) {
                    max_val = val;
                  }
                }
              }
            }
            *o_p = T(max_val); // Output.
          }
        }
      }
    }
  });
}

template <typename Q>
Tensor q_maxpool_2d(
    Tensor qx, // Input Tensor (Quantized)
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sH,
    int64_t sW, // strides
    int64_t pH,
    int64_t pW, // padding
    int64_t dH,
    int64_t dW, // dilation
    bool ceil_mode) {
  // Check input dimensions.
  TORCH_CHECK(kH > 0 && kW > 0, "kernel_size should be greater than zero.");
  TORCH_CHECK(sH > 0 && sW > 0, "strides should be greater than zero.");
  TORCH_CHECK(
      dH > 0 && dW > 0,
      "dilation should be greater than zero. "
      "Got (",
      dH,
      ", ",
      dW,
      ")");

  int ndim = qx.dim();
  TORCH_CHECK(
      ndim == 3 || ndim == 4, "Expected the input tensor to be of rank 3 or 4.");
  int dimc = 0;
  int dimh = 1;
  int dimw = 2;
  int nbatch = 1;
  if (ndim == 4) { // Includes batches
    ++dimc;
    ++dimh;
    ++dimw;
    nbatch = qx.size(0);
  }

  // Check if inputs are valid.
  int64_t iC = qx.size(dimc);
  int64_t iH = qx.size(dimh);
  int64_t iW = qx.size(dimw);
  TORCH_CHECK(iC > 0 && iH > 0 && iW > 0, "input dimensions must be non-zero.");
  TORCH_CHECK(
      (ndim == 3 || ndim == 4),
      "non-empty 3D or 4D input tensor is expected.");
  TORCH_CHECK(
      kH / 2 >= pH && kW / 2 >= pW,
      "padding should be at most half of the kernel size.");

  // Check output dimensions.
  int64_t oC = iC;
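  // pooling_output_shape computes
  //   o = (i + 2 * p - d * (k - 1) - 1) / s + 1
  // with floor division (ceil division when ceil_mode is set, with a
  // correction so the last window starts inside the input or padding).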
  int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
  int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
  TORCH_CHECK(oH > 0 && oW > 0,
              "Given input size: (",
              iC, "x", iH, "x", iW,
              "). Calculated output size: (",
              oC, "x", oH, "x", oW,
              "). Output size is too small.");

  std::vector<int64_t> oSizes;
  if (ndim == 3) {
    oSizes = {oC, oH, oW};
  } else {
    oSizes = {nbatch, oC, oH, oW};
  }

  if (qx.is_contiguous(c10::MemoryFormat::ChannelsLast)) {
    // Fast path for the channels-last case: we can preserve the data
    // layout in memory and use a loop nest that is more amenable to
    // vectorization.
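    // In NHWC the channel dimension is innermost, so each pooling-window
    // max reduces over contiguous memory and maps naturally onto SIMD lanes.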
    Tensor qy;
    if constexpr (std::is_same_v<Q, uint8_t>) {
      qy = at::empty(
        oSizes,
        qx.options()
          .device(c10::kCPU)
          .dtype(qx.scalar_type())
          .memory_format(c10::MemoryFormat::ChannelsLast));
    } else {
      qy = at::_empty_affine_quantized(
          oSizes,
          qx.options()
            .dtype(toQIntType(qx.scalar_type()))
            .memory_format(qx.suggest_memory_format()),
          qx.q_scale(),
          qx.q_zero_point(),
          std::nullopt);
    }
    qmaxpool_2d_nhwc_stub(qx.device().type(), qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
    return qy;
  } else {
    Tensor qy;
    if constexpr (!std::is_same_v<Q, uint8_t>) {
      qy = at::_empty_affine_quantized(
              oSizes,
              qx.options().dtype(toQIntType(qx.scalar_type())),
              qx.q_scale(),
              qx.q_zero_point());
      auto qx_contig = qx.contiguous();
      auto qxd = qx_contig.data_ptr<Q>();
      auto qyd = qy.data_ptr<Q>();
      if (ndim == 3 || nbatch == 1) {
        auto* iData = qxd;
        auto* oData = qyd;
        spatial_dilated_max_pooling<Q>(
            iData,
            iC,
            iH,
            iW,
            oH,
            oW,
            kH,
            kW,
            sH,
            sW,
            pH,
            pW,
            dH,
            dW,
            oData);
      } else {
        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
          for (const auto p : c10::irange(start, end)) {
            auto* iData = qxd + p * iC * iW * iH;
            auto* oData = qyd + p * oC * oW * oH;
            spatial_dilated_max_pooling<Q>(
                iData,
                iC,
                iH,
                iW,
                oH,
                oW,
                kH,
                kW,
                sH,
                sW,
                pH,
                pW,
                dH,
                dW,
                oData);
          }
        });
      }
    } else {
      // If qx is uint8 in contiguous memory format, use the channels-last
      // implementation and convert qy back to contiguous afterwards.
      qy = at::empty(
        oSizes,
        qx.options()
          .device(c10::kCPU)
          .dtype(qx.scalar_type())
          .memory_format(c10::MemoryFormat::ChannelsLast));
      auto qx_nhwc = qx.contiguous(c10::MemoryFormat::ChannelsLast);
      qmaxpool_2d_nhwc_stub(qx_nhwc.device().type(), qx_nhwc, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
      qy = qy.contiguous();
    }
    return qy;
  }
}

template <typename Q>
Tensor q_maxpool_3d(
    Tensor qx, // Input Tensor (Quantized)
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    bool ceil_mode) {
  // Check input dimensions.
  TORCH_CHECK(kT > 0 && kH > 0 && kW > 0, "kernel_size should be greater than zero.");
  TORCH_CHECK(sT > 0 && sH > 0 && sW > 0, "strides should be greater than zero.");
  TORCH_CHECK(
      dT > 0 && dH > 0 && dW > 0,
      "dilation should be greater than zero. "
      "Got (",
      dT,
      ", ",
      dH,
      ", ",
      dW,
      ")");
  int ndim = qx.dim();
  // TODO leslie: Support non-batch mode when the input is a 4-d THWC tensor.
  TORCH_CHECK(ndim == 5, "Expected the input tensor to be of rank 5.");

  // act: n, c, t, h, w
  int dimc = 1;
  int dimt = 2;
  int dimh = 3;
  int dimw = 4;
  int nbatch = qx.size(0);
  // Check if inputs are valid.
  int64_t iC = qx.size(dimc);
  int64_t iT = qx.size(dimt);
  int64_t iH = qx.size(dimh);
  int64_t iW = qx.size(dimw);
  TORCH_CHECK(iC > 0 && iT > 0 && iH > 0 && iW > 0, "input dimensions must be non-zero.");
  TORCH_CHECK(
      kT / 2 >= pT && kH / 2 >= pH && kW / 2 >= pW,
      "padding should be at most half of the kernel size.");

  // Check output dimensions.
  int64_t oC = iC;
  int64_t oT = pooling_output_shape(iT, kT, pT, sT, dT, ceil_mode);
  int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
  int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
  TORCH_CHECK(oT > 0 && oH > 0 && oW > 0,
              "Given input size: (",
              iC, "x", iT, "x", iH, "x", iW,
              "). Calculated output size: (",
              oC, "x", oT, "x", oH, "x", oW,
              "). Output size is too small.");

  std::vector<int64_t> oSizes = {nbatch, oC, oT, oH, oW};

  if (qx.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
    // Fast path for the channels-last case: we can preserve the data
    // layout in memory and use a loop nest that is more amenable to
    // vectorization.
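    // As in the 2D case, channels-last (NDHWC) keeps channels innermost, so
    // the per-window max vectorizes across channels.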
    Tensor qy = at::_empty_affine_quantized(
        oSizes,
        qx.options()
          .dtype(toQIntType(qx.scalar_type()))
          .memory_format(qx.suggest_memory_format()),
        qx.q_scale(),
        qx.q_zero_point(),
        std::nullopt);
    qmaxpool_3d_nthwc_stub(qx.device().type(), qx, iC, iT, iH, iW, oT, oH, oW, kT, kH, kW, sT, sH, sW, pT, pH, pW, dT, dH, dW, qy);
    return qy;
  } else {
    Tensor qy = at::_empty_affine_quantized(
      oSizes,
      qx.options().dtype(toQIntType(qx.scalar_type())),
      qx.q_scale(),
      qx.q_zero_point());
    auto qx_contig = qx.contiguous();
    auto qxd = qx_contig.data_ptr<Q>();
    auto qyd = qy.data_ptr<Q>();

    spatial_dilated_max_pooling3d<Q>(
        qxd,
        nbatch,
        iC,
        iT,
        iH,
        iW,
        oT,
        oH,
        oW,
        kT,
        kH,
        kW,
        sT,
        sH,
        sW,
        pT,
        pH,
        pW,
        dT,
        dH,
        dW,
        qyd);

    return qy;
  }
}
} // namespace

namespace {
void check_maxpool2d_params(
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
              "Expected 1d or 2d kernel size, got ", kernel_size.size());
  TORCH_CHECK(stride.empty() || stride.size() == 2,
              "Expected no strides or 2d strides, got ", stride.size());
  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
              "Expected 1d or 2d padding, got ", padding.size());
  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
              "Expected 1d or 2d dilation, got ", dilation.size());
}

void check_maxpool3d_params(
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  TORCH_CHECK(kernel_size.size() == 3, "Expected 3d kernel size, got ", kernel_size.size());
  TORCH_CHECK(stride.empty() || stride.size() == 3,
              "Expected no strides or 3d strides, got ", stride.size());
  TORCH_CHECK(padding.size() == 3, "Expected 3d padding, got ", padding.size());
  TORCH_CHECK(dilation.size() == 3, "Expected 3d dilation, got ", dilation.size());
}

#ifdef USE_PYTORCH_QNNPACK
static Tensor qnnpack_maxpool2d(
    Tensor input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  Tensor qy;

  TORCH_CHECK(
      input.ndimension() == 4,
      "qnnpack_maxpool2d(): Expected input to be 4-dimensional: got ",
      input.ndimension());
  TORCH_CHECK(
      kernel_size.size() == 2,
      "qnnpack_maxpool2d(): Expected kernel_size to be 2-dimensional: got ",
      kernel_size.size());
  TORCH_CHECK(
      stride.size() == 2,
      "qnnpack_maxpool2d(): Expected stride to be 2-dimensional: got ",
      stride.size());
  TORCH_CHECK(
      dilation.size() == 2,
      "qnnpack_maxpool2d(): Expected dilation to be 2-dimensional: got ",
      dilation.size());
  TORCH_CHECK(
      padding.size() == 2,
      "qnnpack_maxpool2d(): Expected padding to be 2-dimensional: got ",
      padding.size());

  int64_t batch_size = input.size(0);
  int64_t inC = input.size(1);
  int64_t inH = input.size(2);
  int64_t inW = input.size(3);
  Tensor input_contig = input.contiguous(MemoryFormat::ChannelsLast);

  initQNNPACK();
  const auto scale = input_contig.q_scale();
  const auto zero_point = input_contig.q_zero_point();
  pytorch_qnnp_operator_t qnnpack_operator{nullptr};
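  // QNNPACK uses a create/setup/run operator lifecycle: create bakes the
  // pooling geometry into the operator, setup binds the input/output buffers
  // for this batch, and run executes the operator on the thread pool.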

  int64_t padH = padding[0];
  int64_t padW = padding[1];
  int64_t kH = kernel_size[0];
  int64_t kW = kernel_size[1];
  int64_t strideH = stride[0];
  int64_t strideW = stride[1];
  int64_t dilationH = dilation[0];
  int64_t dilationW = dilation[1];

  TORCH_CHECK(
      kH > 0 && kW > 0,
      "qnnpack_maxpool2d(): kernel_size should be greater than zero.");
  TORCH_CHECK(
      strideH > 0 && strideW > 0,
      "qnnpack_maxpool2d(): strides should be greater than zero.");

  const pytorch_qnnp_status createStatus =
      pytorch_qnnp_create_max_pooling2d_nhwc_u8(
          padH /* input_padding_height */,
          padW /* input_padding_width */,
          kH /* pooling height */,
          kW /* pooling width */,
          strideH /* stride height */,
          strideW /* stride width */,
          dilationH /* dilation height */,
          dilationW /* dilation width */,
          inC /* input channels */,
          std::numeric_limits<uint8_t>::min() /* output min */,
          std::numeric_limits<uint8_t>::max() /* output max */,
          0 /* flags */,
          &qnnpack_operator);
  TORCH_INTERNAL_ASSERT(
      createStatus == pytorch_qnnp_status_success,
      "failed to create QNNPACK MaxPool operator");

  int64_t outC = inC;
  int64_t outH =
      pooling_output_shape(inH, kH, padH, strideH, dilationH, ceil_mode);
  int64_t outW =
      pooling_output_shape(inW, kW, padW, strideW, dilationW, ceil_mode);

  TORCH_CHECK(
      outH > 0 && outW > 0,
      "qnnpack_maxpool2d(): the resulting output Tensor size should be positive");

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(qnnpack_operator);

  // NHWC output
  qy = at::_empty_affine_quantized(
      {batch_size, outC, outH, outW},
      at::device(kCPU).dtype(kQUInt8),
      scale,
      zero_point,
      MemoryFormat::ChannelsLast);

  const pytorch_qnnp_status setupStatus =
      pytorch_qnnp_setup_max_pooling2d_nhwc_u8(
          qnnpack_operator /* max pooling */,
          batch_size /* batch size */,
          inH /* input height */,
          inW /* input width */,
          (uint8_t*)input_contig.data_ptr<c10::quint8>() /* input */,
          inC /* input_pixel_stride */,
          (uint8_t*)qy.data_ptr<c10::quint8>() /* output data */,
          outC /* output_pixel_stride */,
          nullptr /* thread pool */);
  TORCH_INTERNAL_ASSERT(
      setupStatus == pytorch_qnnp_status_success,
      "failed to setup QNNPACK MaxPool operator");

  pthreadpool_t threadpool = caffe2::pthreadpool_();
  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK MaxPool operator");
  return qy.contiguous(input.suggest_memory_format());
}
#endif
}  // namespace

// at::native functions for native_functions.yaml
Tensor quantized_max_pool2d(
    const Tensor& qx,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  check_maxpool2d_params(
      kernel_size,
      stride,
      padding,
      dilation);
  if (stride.empty()) {
    stride = kernel_size;
  }
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8 && !ceil_mode) {
    return qnnpack_maxpool2d(qx, kernel_size, stride, padding, dilation, ceil_mode);
  }
#endif
  Tensor qy;
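  // Dispatch over the quantized cell types (qint8, quint8, qint32) plus plain
  // uint8 (Byte); scalar_t inside the lambda is the selected element type.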
  AT_DISPATCH_QINT_TYPES_AND(ScalarType::Byte, qx.scalar_type(), "max_pool2d", [&]() {
    qy = q_maxpool_2d<scalar_t>(
        qx,
        kernel_size[0],
        kernel_size[1],
        stride[0],
        stride[1],
        padding[0],
        padding[1],
        dilation[0],
        dilation[1],
        ceil_mode);
  });
  return qy;
}

Tensor quantized_max_pool3d(
    const Tensor& qx,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  check_maxpool3d_params(
      kernel_size,
      stride,
      padding,
      dilation);
  if (stride.empty()) {
    stride = kernel_size;
  }
#ifdef USE_PYTORCH_QNNPACK
  TORCH_CHECK(at::globalContext().qEngine() != at::QEngine::QNNPACK,
              "the QNNPACK backend doesn't support quantized_max_pool3d");
#endif
  Tensor qy;
  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool3d", [&]() {
    qy = q_maxpool_3d<scalar_t>(
        qx,
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        stride[0],
        stride[1],
        stride[2],
        padding[0],
        padding[1],
        padding[2],
        dilation[0],
        dilation[1],
        dilation[2],
        ceil_mode);
  });
  return qy;
}

// Quantized max_pool1d is a special case of max_pool2d, with one spatial
// dimension (and the corresponding kernel dimension) removed.
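// For example, an (N, C, L) input is viewed as (N, C, 1, L) and pooled with
// kernel {1, kW}, stride {1, sW}, padding {0, pW}, dilation {1, dW}; the
// inserted height dimension is squeezed away afterwards.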
Tensor quantized_max_pool1d(
    const Tensor& qx,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  check_max_pool1d(qx, kernel_size, stride, padding, dilation, ceil_mode);
  // (C, L) -> (C, 1, L) => kSqueezeDim = 1
  // (N, C, L) -> (N, C, 1, L) => kSqueezeDim = 2
  const int32_t kSqueezeDim = qx.dim() - 1;
  const auto qx_unsqueeze = qx.unsqueeze(kSqueezeDim);
  if (stride.empty()) {
    stride = kernel_size;
  }
  auto qy = at::quantized_max_pool2d(
    qx_unsqueeze,
    {1, kernel_size[0]},
    {1, stride[0]},
    {0, padding[0]},
    {1, dilation[0]},
    ceil_mode);
  qy = qy.squeeze(kSqueezeDim);
  return qy;
}

// Keep the registry in the anonymous namespace.
namespace {
template <uint32_t kSpatialDim>
class QMaxPool_arr_args final {
 public:
  static Tensor run(
      const Tensor& qx,
      std::vector<int64_t> kernel_size,
      std::vector<int64_t> stride,
      std::vector<int64_t> padding,
      std::vector<int64_t> dilation,
      bool ceil_mode) {
    if (!qx.is_quantized() && kSpatialDim == 2 && qx.scalar_type() == c10::ScalarType::Byte) {
      return at::native::quantized_max_pool2d(qx, kernel_size, stride, padding,
                                      dilation, ceil_mode);
    }
    if (kSpatialDim == 1) {
      return at::quantized_max_pool1d(qx, kernel_size, stride, padding,
                                      dilation, ceil_mode);
    } else if (kSpatialDim == 2) {
      return at::quantized_max_pool2d(qx, kernel_size, stride, padding,
                                      dilation, ceil_mode);
    }
    TORCH_CHECK(false, "MaxPool", kSpatialDim, "D is not supported.");
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool1d"), TORCH_FN(QMaxPool_arr_args<1>::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool2d"), TORCH_FN(QMaxPool_arr_args<2>::run));
}

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool2d"), TORCH_FN(QMaxPool_arr_args<2>::run));
}
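
// A minimal call sketch (illustrative only; the tensor sizes and quantization
// parameters below are assumed, not taken from this file):
//   Tensor x = at::rand({1, 3, 8, 8});
//   Tensor qx = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/0, kQUInt8);
//   Tensor qy = at::quantized_max_pool2d(
//       qx, /*kernel_size=*/{2, 2}, /*stride=*/{2, 2}, /*padding=*/{0, 0},
//       /*dilation=*/{1, 1}, /*ceil_mode=*/false);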

} // namespace
} // namespace native
} // namespace at