#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/div_rtn.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <limits>
#include <optional>
#include <utility>
10
11 namespace at::native {
12
// Forward max-pool kernel writes both the pooled values (output) and the
// argmax positions (indices); the backward kernel routes grad_output back
// through those indices.
using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
    int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);

// average pooling has same signature for forward and backward
// NOTE(review): the forward alias takes int64_t geometry parameters while the
// backward alias takes int; kept as-is because kernel implementations are
// registered against these exact signatures.
using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
    int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional<int64_t> divisor_override);
using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
    int dW, int dH, int padW, int padH, bool count_include_pad, std::optional<int64_t> divisor_override);

DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
28
// average pooling has same signature for forward and backward
using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input,
    int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD,
    int64_t padW, int64_t padH, int64_t padD, bool count_include_pad,
    std::optional<int64_t> divisor_override);
// NOTE(review): backward alias mirrors the forward one but with int (not
// int64_t) geometry parameters; kept as-is to match registered kernels.
using avg_pool3d_backward_fn = void(*)(const Tensor& output, const Tensor& input,
    int kW, int kH, int kD, int dW, int dH, int dD,
    int padW, int padH, int padD, bool count_include_pad,
    std::optional<int64_t> divisor_override);

DECLARE_DISPATCH(avg_pool3d_fn, avg_pool3d_kernel);
DECLARE_DISPATCH(avg_pool3d_backward_fn, avg_pool3d_backward_kernel);

// 3d max pooling produces values and argmax indices. Note these aliases take
// mutable Tensor& for the outputs, unlike the 2d variants' const Tensor&.
using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
    int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel);
DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);
48 namespace {
49
// Narrowing cast from src_t to dest_t that raises (via TORCH_CHECK) instead
// of silently truncating when `v` is outside dest_t's representable range.
template <typename dest_t, typename src_t>
inline dest_t
safe_downcast(src_t v)
{
  constexpr dest_t kLow = std::numeric_limits<dest_t>::min();
  constexpr dest_t kHigh = std::numeric_limits<dest_t>::max();
  TORCH_CHECK(kLow <= v && v <= kHigh, "integer out of range");

  return static_cast<dest_t>(v);
}
59
60 template<typename T>
pooling_output_shape_pad_lr(T inputSize,T kernelSize,T pad_l,T pad_r,T stride,T dilation,bool ceil_mode)61 inline T pooling_output_shape_pad_lr(
62 T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
63 bool ceil_mode) {
64 T outputSize = div_rtn<T>(
65 inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
66 (ceil_mode ? stride - 1 : 0), stride) + 1;
67 if (ceil_mode) {
68 // ensure that the last pooling starts inside the image
69 // needed to avoid problems in ceil mode
70 if ((outputSize - 1) * stride >= inputSize + pad_l) {
71 --outputSize;
72 }
73 }
74 return outputSize;
75 }
76
// Output extent of one pooled dimension with symmetric padding; delegates to
// pooling_output_shape_pad_lr with pad used on both sides.
//
// Raises (via TORCH_CHECK) if stride is zero, pad is negative, or pad exceeds
// half of the effective kernel size, (kernelSize - 1) * dilation + 1.
template<typename T>
inline T pooling_output_shape(
      T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
  TORCH_CHECK(stride != 0, "stride should not be zero");
  TORCH_CHECK(pad >= 0,
              "pad must be non-negative, but got pad: ", pad);
  TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2,
              "pad should be at most half of effective kernel size, but got pad=",
              pad, ", kernel_size=", kernelSize, " and dilation=", dilation);  // fixed: statement-macro was missing its terminating ';'
  return pooling_output_shape_pad_lr(
      inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
}
89
// Computes (left, right) padding for "same"-mode pooling of one dimension,
// so that the output length is ceil(inputSize / stride). The padding total is
// dilation * (kernelSize - 1); when it is odd the extra unit goes on the
// right, except that for stride > 2 an odd total may first be reduced by one
// when the floor in the output-size formula leaves slack.
template <typename T>
std::pair<T, T> _pooling_same_mode_padding_lr(
    T inputSize, T kernelSize, T stride, T dilation) {
  auto total = T(dilation) * (kernelSize - 1);

  // Prefer symmetric padding when the slack allows it.
  if (stride > 2 && (total % 2 == 1)) {
    const auto slack = inputSize % stride - 1;
    if (slack > 0) {
      total = total - 1;
    }
  }

  const auto left = total / 2;
  return {left, total - left};
}
108
pooling_same_mode_padding_lr(int64_t inputSize,int64_t kernelSize,int64_t stride,int64_t dilation)109 inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
110 int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
111 return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
112 }
113
pooling_same_mode_padding_lr(c10::SymInt inputSize,c10::SymInt kernelSize,c10::SymInt stride,c10::SymInt dilation)114 inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
115 c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) {
116 return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation));
117 }
118
// AveragePool2d/DilatedMaxPool2d (forward)
//
// Validates the hyper-parameters (kernel, stride, dilation, padding) and the
// input tensor's shape for 2d pooling, and checks that the computed output
// extent (outputHeight x outputWidth) is at least 1x1. Every failure raises
// via TORCH_CHECK.
inline void
pool2d_shape_check(
  const Tensor& input,
  int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, int64_t dilationH, int64_t dilationW,
  int64_t nInputPlane,
  int64_t inputHeight, int64_t inputWidth,
  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
{
  const int64_t ndim = input.ndimension();
#ifndef STRIP_ERROR_MESSAGES
  // Only referenced inside a TORCH_CHECK message below, hence the guard.
  const int64_t nOutputPlane = nInputPlane;
#endif

  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got ",
              "kH: ", kH, " kW: ", kW);
  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got "
              "dH: ", dH, " dW: ", dW);
  TORCH_CHECK(dilationH > 0 && dilationW > 0,
              "dilation should be greater than zero, but got ",
              "dilationH: ", dilationH, " dilationW: ", dilationW);

  // Dims 1 and 2 must be non-empty for both 3D and 4D inputs; only the batch
  // dimension of a 4D input is allowed to be zero.
  bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
  if (memory_format == at::MemoryFormat::ChannelsLast){
    // Expect tensor in NHWC format and allow 0-dim only for N.
    TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
      "Expected 4D (batch mode) tensor expected for input with channels_last layout"
      " with optional 0 dim batch size for input, but got: ", input.sizes());
  } else {
    TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
      (ndim == 4 && valid_dims && input.size(3) != 0),
      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
      input.sizes());
  }

  // NOTE(review): this bound ignores dilation, unlike the effective-kernel
  // bound enforced by pooling_output_shape — confirm that is intentional.
  TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
    "pad should be smaller than or equal to half of kernel size, but got ",
    "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);

  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
    "Given input size: (",
    nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
    "Calculated output size: (",
    nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
    "Output size is too small");
}
167
168 // DilatedMaxPool2d (backward)
169 inline void
max_pool2d_backward_shape_check(const Tensor & input,const Tensor & gradOutput,const Tensor & indices,int kH,int kW,int dH,int dW,int padH,int padW,int dilationH,int dilationW,int64_t nInputPlane,int64_t inputHeight,int64_t inputWidth,int64_t outputHeight,int64_t outputWidth,MemoryFormat memory_format)170 max_pool2d_backward_shape_check(
171 const Tensor& input,
172 const Tensor& gradOutput,
173 const Tensor& indices,
174 int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
175 int64_t nInputPlane,
176 int64_t inputHeight, int64_t inputWidth,
177 int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
178 {
179 pool2d_shape_check(
180 input,
181 kH, kW, dH, dW, padH, padW, dilationH, dilationW,
182 nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
183
184 const int64_t ndim = input.ndimension();
185 const int64_t nOutputPlane = nInputPlane;
186
187 check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
188 check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
189 check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
190
191 check_dim_size(indices, ndim, ndim-3, nOutputPlane);
192 check_dim_size(indices, ndim, ndim-2, outputHeight);
193 check_dim_size(indices, ndim, ndim-1, outputWidth);
194 }
195
196 // AveragePool2d (backward)
197 inline void
avg_pool2d_backward_shape_check(const Tensor & input,const Tensor & gradOutput,int64_t,int kH,int kW,int dH,int dW,int padH,int padW,int64_t nInputPlane,int64_t inputHeight,int64_t inputWidth,int64_t outputHeight,int64_t outputWidth,MemoryFormat memory_format)198 avg_pool2d_backward_shape_check(
199 const Tensor& input,
200 const Tensor& gradOutput,
201 int64_t /*nbatch*/,
202 int kH, int kW, int dH, int dW, int padH, int padW,
203 int64_t nInputPlane,
204 int64_t inputHeight, int64_t inputWidth,
205 int64_t outputHeight, int64_t outputWidth,
206 MemoryFormat memory_format)
207 {
208 pool2d_shape_check(
209 input,
210 kH, kW, dH, dW, padH, padW, 1, 1,
211 nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
212 memory_format);
213
214 const int64_t ndim = input.ndimension();
215 const int64_t nOutputPlane = nInputPlane;
216
217 check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
218 check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
219 check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
220 }
221
222 // AveragePool3d/DilatedMaxPool3d (forward)
223 inline void
224 pool3d_shape_check(
225 const Tensor& input,
226 int64_t nslices,
227 int kT, int kH, int kW,
228 int dT, int dH, int dW,
229 int pT, int pH, int pW,
230 int dilationT, int dilationH, int dilationW,
231 int64_t itime, int64_t iheight, int64_t iwidth,
232 int64_t otime, int64_t oheight, int64_t owidth,
233 const char *fn_name,
234 bool check_input_size=false)
235 {
236 const int64_t ndim = input.ndimension();
237
238 TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
239 "kernel size should be greater than zero, but got ",
240 "kT: ", kT, " kH: ", kH, " kW: ", kW);
241 TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
242 "stride should be greater than zero, but got ",
243 "dT: ", dT, " dH: ", dH, " dW: ", dW);
244 TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
245 "dilation should be greater than zero, but got ",
246 "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
247
248 TORCH_CHECK(ndim == 4 || ndim == 5,
249 fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
250
251 for (const auto i : c10::irange(ndim)) {
252 if (ndim == 5 && i == 0) {
253 // size of batch-dim can be 0.
254 continue;
255 }
256 TORCH_CHECK(
257 input.size(i) > 0,
258 fn_name,
259 ": Expected input's non-batch dimensions to have positive length,"
260 " but input has a shape of ",
261 input.sizes(),
262 " and non-batch dimension ",
263 input.size(i),
264 " has length zero!")
265 }
266
267 if (check_input_size) { // AveragePool3d
268 TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
269 "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
270 "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
271 }
272
273 TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
274 "pad should be smaller than or equal to half of kernel size, but got "
275 "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
276
277 TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
278 "Given input size: (",
279 nslices,"x", itime, "x", iheight, "x", iwidth, "). ",
280 "Calculated output size: (",
281 nslices, "x", otime, "x", oheight, "x", owidth, "). ",
282 "Output size is too small");
283 }
284
285 inline void
max_pool3d_backward_shape_check(const Tensor & input,const Tensor & gradOutput,const Tensor & indices,int64_t nslices,int kT,int kH,int kW,int dT,int dH,int dW,int pT,int pH,int pW,int dilationT,int dilationH,int dilationW,int64_t itime,int64_t iheight,int64_t iwidth,int64_t otime,int64_t oheight,int64_t owidth,const char * fn_name)286 max_pool3d_backward_shape_check(
287 const Tensor& input,
288 const Tensor& gradOutput,
289 const Tensor& indices,
290 int64_t nslices,
291 int kT, int kH, int kW,
292 int dT, int dH, int dW,
293 int pT, int pH, int pW,
294 int dilationT, int dilationH, int dilationW,
295 int64_t itime, int64_t iheight, int64_t iwidth,
296 int64_t otime, int64_t oheight, int64_t owidth,
297 const char* fn_name)
298 {
299 const int64_t ndim = input.ndimension();
300
301 pool3d_shape_check(
302 input,
303 nslices,
304 kT, kH, kW,
305 dT, dH, dW,
306 pT, pH, pW,
307 dilationT, dilationH, dilationW,
308 itime, iheight, iwidth,
309 otime, oheight, owidth, fn_name);
310
311 check_dim_size(gradOutput, ndim, ndim-4, nslices);
312 check_dim_size(gradOutput, ndim, ndim-3, otime);
313 check_dim_size(gradOutput, ndim, ndim-2, oheight);
314 check_dim_size(gradOutput, ndim, ndim-1, owidth);
315
316 check_dim_size(indices, ndim, ndim-4, nslices);
317 check_dim_size(indices, ndim, ndim-3, otime);
318 check_dim_size(indices, ndim, ndim-2, oheight);
319 check_dim_size(indices, ndim, ndim-1, owidth);
320 }
321
322 inline void
avg_pool3d_backward_shape_check(const Tensor & input,const Tensor & gradOutput,int64_t nslices,int kT,int kH,int kW,int dT,int dH,int dW,int pT,int pH,int pW,int64_t itime,int64_t iheight,int64_t iwidth,int64_t otime,int64_t oheight,int64_t owidth,const char * fn_name)323 avg_pool3d_backward_shape_check(
324 const Tensor& input,
325 const Tensor& gradOutput,
326 int64_t nslices,
327 int kT, int kH, int kW,
328 int dT, int dH, int dW,
329 int pT, int pH, int pW,
330 int64_t itime, int64_t iheight, int64_t iwidth,
331 int64_t otime, int64_t oheight, int64_t owidth,
332 const char *fn_name)
333 {
334 const int64_t ndim = input.ndimension();
335
336 pool3d_shape_check(
337 input,
338 nslices,
339 kT, kH, kW,
340 dT, dH, dW,
341 pT, pH, pW,
342 1, 1, 1,
343 itime, iheight, iwidth,
344 otime, oheight, owidth,
345 fn_name, true);
346
347 check_dim_size(gradOutput, ndim, ndim-4, nslices);
348 check_dim_size(gradOutput, ndim, ndim-3, otime);
349 check_dim_size(gradOutput, ndim, ndim-2, oheight);
350 check_dim_size(gradOutput, ndim, ndim-1, owidth);
351 }
352
353 } // anonymous namespace
354
355 } // namespace at::native
356