#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/div_rtn.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/irange.h>

#include <limits>
#include <optional>
#include <utility>

namespace at::native {

using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
    int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);
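
// Note: these stubs only declare the dispatch entry points; the per-backend
// kernels are registered separately (typically via REGISTER_DISPATCH in the
// corresponding kernel .cpp files).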

// average pooling has the same signature for forward and backward
using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
    int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional<int64_t> divisor_override);
using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
    int dW, int dH, int padW, int padH, bool count_include_pad, std::optional<int64_t> divisor_override);

DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);

// average pooling has the same signature for forward and backward
using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input,
    int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD,
    int64_t padW, int64_t padH, int64_t padD, bool count_include_pad,
    std::optional<int64_t> divisor_override);
using avg_pool3d_backward_fn = void(*)(const Tensor& output, const Tensor& input,
    int kW, int kH, int kD, int dW, int dH, int dD,
    int padW, int padH, int padD, bool count_include_pad,
    std::optional<int64_t> divisor_override);

DECLARE_DISPATCH(avg_pool3d_fn, avg_pool3d_kernel);
DECLARE_DISPATCH(avg_pool3d_backward_fn, avg_pool3d_backward_kernel);

using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
    int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel);
DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);

namespace {

template <typename dest_t, typename src_t>
inline dest_t
safe_downcast(src_t v)
{
  TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
              "integer out of range");

  return static_cast<dest_t>(v);
}
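
// Example: safe_downcast<int, int64_t>(kernel_size) returns kernel_size as an
// int, and raises an error through TORCH_CHECK when the value does not fit in
// the destination type.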

template<typename T>
inline T pooling_output_shape_pad_lr(
        T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
        bool ceil_mode) {
    T outputSize = div_rtn<T>(
        inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
        (ceil_mode ? stride - 1 : 0), stride) + 1;
    if (ceil_mode) {
        // ensure that the last pooling starts inside the image
        // needed to avoid problems in ceil mode
        if ((outputSize - 1) * stride >= inputSize + pad_l) {
          --outputSize;
        }
    }
    return outputSize;
}

template<typename T>
inline T pooling_output_shape(
      T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
    TORCH_CHECK(stride != 0, "stride should not be zero");
    TORCH_CHECK(pad >= 0,
                "pad must be non-negative, but got pad: ", pad);
    TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2,
                "pad should be at most half of effective kernel size, but got pad=",
                pad, ", kernel_size=", kernelSize, " and dilation=", dilation);
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
}
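
// Worked example (illustrative): inputSize=7, kernelSize=3, pad=1, stride=2,
// dilation=1, ceil_mode=false gives
//   floor((7 + 2*1 - 1*(3-1) - 1) / 2) + 1 = floor(6 / 2) + 1 = 4.
// With ceil_mode=true the numerator gains (stride - 1), and the result is
// decremented if the last pooling window would not start inside the image.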

template <typename T>
std::pair<T, T> _pooling_same_mode_padding_lr(
    T inputSize, T kernelSize, T stride, T dilation) {
  // NOTE: with strides, the output shape is ceil(inputSize/stride)
  auto total_padding = T(dilation) * (kernelSize - 1);

  // Prefer symmetric padding if possible
  if (stride > 2 && (total_padding % 2 == 1)) {
    // The floor in the output size calculation gives us a little wiggle room
    auto wiggle_room = inputSize % stride - 1;
    if (wiggle_room > 0) {
      total_padding = total_padding - 1;
    }
  }

  auto left = total_padding / 2;
  return {left, total_padding - left};
}

inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
    int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
}

inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
    c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) {
  return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation));
}
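
// Illustrative values: kernelSize=3, dilation=1 yields total_padding=2 and a
// symmetric split of {1, 1}; kernelSize=4, dilation=1, stride=1 yields
// total_padding=3 and an asymmetric split of {1, 2} (the extra row/column of
// padding goes on the right).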

// AveragePool2d/DilatedMaxPool2d (forward)
inline void
pool2d_shape_check(
  const Tensor& input,
  int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, int64_t dilationH, int64_t dilationW,
  int64_t nInputPlane,
  int64_t inputHeight, int64_t inputWidth,
  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
{
  const int64_t ndim = input.ndimension();
#ifndef STRIP_ERROR_MESSAGES
  const int64_t nOutputPlane = nInputPlane;
#endif

  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got ",
              "kH: ", kH, " kW: ", kW);
  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got "
              "dH: ", dH, " dW: ", dW);
  TORCH_CHECK(dilationH > 0 && dilationW > 0,
              "dilation should be greater than zero, but got ",
              "dilationH: ", dilationH, " dilationW: ", dilationW);

  bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
  if (memory_format == at::MemoryFormat::ChannelsLast) {
    // Expect tensor in NHWC format and allow 0-dim only for N.
    TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
      "Expected 4D (batch mode) tensor for input with channels_last layout"
      " with optional 0 dim batch size, but got: ", input.sizes());
  } else {
    TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
      (ndim == 4 && valid_dims && input.size(3) != 0),
      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: ",
      input.sizes());
  }

  TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
              "pad should be smaller than or equal to half of kernel size, but got ",
              "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);

  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
              "Given input size: (",
              nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
              "Calculated output size: (",
              nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
              "Output size is too small");
}
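
// Typical call pattern (illustrative sketch based on the signatures above):
// the caller first derives the output extent with pooling_output_shape and
// then validates everything here, e.g.
//   const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
//   const int64_t outputWidth  = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
//   pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW,
//                      nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);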

// DilatedMaxPool2d (backward)
inline void
max_pool2d_backward_shape_check(
  const Tensor& input,
  const Tensor& gradOutput,
  const Tensor& indices,
  int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
  int64_t nInputPlane,
  int64_t inputHeight, int64_t inputWidth,
  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
{
  pool2d_shape_check(
    input,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);

  const int64_t ndim = input.ndimension();
  const int64_t nOutputPlane = nInputPlane;

  check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
  check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
  check_dim_size(gradOutput, ndim, ndim-1, outputWidth);

  check_dim_size(indices, ndim, ndim-3, nOutputPlane);
  check_dim_size(indices, ndim, ndim-2, outputHeight);
  check_dim_size(indices, ndim, ndim-1, outputWidth);
}

// AveragePool2d (backward)
inline void
avg_pool2d_backward_shape_check(
  const Tensor& input,
  const Tensor& gradOutput,
  int64_t /*nbatch*/,
  int kH, int kW, int dH, int dW, int padH, int padW,
  int64_t nInputPlane,
  int64_t inputHeight, int64_t inputWidth,
  int64_t outputHeight, int64_t outputWidth,
  MemoryFormat memory_format)
{
  pool2d_shape_check(
    input,
    kH, kW, dH, dW, padH, padW, 1, 1,
    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
    memory_format);

  const int64_t ndim = input.ndimension();
  const int64_t nOutputPlane = nInputPlane;

  check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
  check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
  check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
}

// AveragePool3d/DilatedMaxPool3d (forward)
inline void
pool3d_shape_check(
  const Tensor& input,
  int64_t nslices,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int pT, int pH, int pW,
  int dilationT, int dilationH, int dilationW,
  int64_t itime, int64_t iheight, int64_t iwidth,
  int64_t otime, int64_t oheight, int64_t owidth,
  const char *fn_name,
  bool check_input_size=false)
{
  const int64_t ndim = input.ndimension();

  TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got ",
              "kT: ", kT, " kH: ", kH, " kW: ", kW);
  TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
              "stride should be greater than zero, but got ",
              "dT: ", dT, " dH: ", dH, " dW: ", dW);
  TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
              "dilation should be greater than zero, but got ",
              "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);

  TORCH_CHECK(ndim == 4 || ndim == 5,
              fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());

  for (const auto i : c10::irange(ndim)) {
    if (ndim == 5 && i == 0) {
      // size of batch-dim can be 0.
      continue;
    }
    TORCH_CHECK(
        input.size(i) > 0,
        fn_name,
        ": Expected input's non-batch dimensions to have positive length,"
        " but input has a shape of ",
        input.sizes(),
        " and non-batch dimension ",
        i,
        " has length zero!");
  }

  if (check_input_size) { // AveragePool3d
    TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
                "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
                "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
  }

  TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
              "pad should be smaller than or equal to half of kernel size, but got "
              "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);

  TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
              "Given input size: (",
              nslices, "x", itime, "x", iheight, "x", iwidth, "). ",
              "Calculated output size: (",
              nslices, "x", otime, "x", oheight, "x", owidth, "). ",
              "Output size is too small");
}

inline void
max_pool3d_backward_shape_check(
  const Tensor& input,
  const Tensor& gradOutput,
  const Tensor& indices,
  int64_t nslices,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int pT, int pH, int pW,
  int dilationT, int dilationH, int dilationW,
  int64_t itime, int64_t iheight, int64_t iwidth,
  int64_t otime, int64_t oheight, int64_t owidth,
  const char* fn_name)
{
  const int64_t ndim = input.ndimension();

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth, fn_name);

  check_dim_size(gradOutput, ndim, ndim-4, nslices);
  check_dim_size(gradOutput, ndim, ndim-3, otime);
  check_dim_size(gradOutput, ndim, ndim-2, oheight);
  check_dim_size(gradOutput, ndim, ndim-1, owidth);

  check_dim_size(indices, ndim, ndim-4, nslices);
  check_dim_size(indices, ndim, ndim-3, otime);
  check_dim_size(indices, ndim, ndim-2, oheight);
  check_dim_size(indices, ndim, ndim-1, owidth);
}

inline void
avg_pool3d_backward_shape_check(
  const Tensor& input,
  const Tensor& gradOutput,
  int64_t nslices,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int pT, int pH, int pW,
  int64_t itime, int64_t iheight, int64_t iwidth,
  int64_t otime, int64_t oheight, int64_t owidth,
  const char *fn_name)
{
  const int64_t ndim = input.ndimension();

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    1, 1, 1,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    fn_name, true);

  check_dim_size(gradOutput, ndim, ndim-4, nslices);
  check_dim_size(gradOutput, ndim, ndim-3, otime);
  check_dim_size(gradOutput, ndim, ndim-2, oheight);
  check_dim_size(gradOutput, ndim, ndim-1, owidth);
}

} // anonymous namespace

} // namespace at::native