xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/MaxPooling.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/Pool.h>
7 
8 namespace at::native {
9 
check_max_pool1d(const Tensor & self,IntArrayRef kernel_size,IntArrayRef stride,IntArrayRef padding,IntArrayRef dilation,bool ceil_mode)10 inline void check_max_pool1d(
11     const Tensor& self,
12     IntArrayRef kernel_size,
13     IntArrayRef stride,
14     IntArrayRef padding,
15     IntArrayRef dilation,
16     bool ceil_mode) {
17 
18   TORCH_CHECK(
19       self.dim() == 2 || self.dim() == 3,
20       "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
21   TORCH_CHECK(
22       kernel_size.size() == 1,
23       "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
24       kernel_size.size());
25   TORCH_CHECK(
26       stride.empty() || stride.size() == 1,
27       "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
28       stride.size());
29   TORCH_CHECK(
30       padding.size() == 1,
31       "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
32       padding.size());
33   TORCH_CHECK(
34       dilation.size() == 1,
35       "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
36       dilation.size());
37 
38   // If stride=None then set it to kernel_size
39   if (stride.empty()) {
40     stride = kernel_size;
41   }
42 
43   TORCH_CHECK(
44       kernel_size[0] > 0,
45       "max_pool1d() kernel_size must be greater than zero, but got ",
46       kernel_size[0]);
47   TORCH_CHECK(
48       stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
49   TORCH_CHECK(
50       padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
51   TORCH_CHECK(
52       padding[0] <= kernel_size[0] / 2,
53       "max_pool1d() padding should be at most half of kernel size, but got padding=",
54       padding[0],
55       " and kernel_size=",
56       kernel_size[0]);
57   TORCH_CHECK(
58       dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
59 
60   const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
61   TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
62 }
63 
64 // TODO(Heitor) Template by dimension
65 struct PoolingParams1D {
66   int64_t NB; // Number of batches
67   int64_t NC; // Number of channels
68   int64_t IW; // Input width
69   int64_t OW; // Output width
70   int64_t KW; // Kernel width
71   int64_t SJ; // Column stride
72   int64_t PJ; // Column padding
73   int64_t DJ; // Column dilation
74 
75   // Return index of input element for the given kernel and output index
indexPoolingParams1D76   inline int64_t index(int64_t kj, int64_t oj) const {
77     return oj * SJ + kj * DJ - PJ;
78   }
79 
80   // Return index of first output within bounds for this kernel index
valid_output_startPoolingParams1D81   inline int64_t valid_output_start(int64_t kj) const {
82     int64_t ij = index(kj, 0);;
83     return ij < 0 ? at::divup(-ij, SJ) : 0;
84   }
85 
86   // Return index one past last output within bounds for this kernel index
valid_output_endPoolingParams1D87   inline int64_t valid_output_end(int64_t kj) const {
88     int64_t ij = index(kj, OW - 1);
89     return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
90   }
91 };
92 
// Signature of a max_pool1d CPU kernel: (output, input, pooling geometry).
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);

// Dispatch stub for the kernel; declared via native/DispatchStub.h so a
// backend-specific implementation can be registered elsewhere.
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
96 
97 } // namespace at::native
98