#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/Parallel.h>
#include <ATen/native/Pool.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/avg_pool3d_backward_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#endif

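// CPU implementation of 3D average pooling (forward and backward). The meta
// functions below validate arguments and compute output shapes; the impl
// functions dispatch to per-slice kernels parallelized with at::parallel_for.
//
// A minimal usage sketch, assuming the standard ATen operator signature
// (shown for orientation only):
//
//   // 5D input: (N, C, T, H, W)
//   at::Tensor input = at::randn({2, 3, 8, 16, 16});
//   at::Tensor out = at::avg_pool3d(
//       input,
//       /*kernel_size=*/{2, 2, 2},
//       /*stride=*/{2, 2, 2},
//       /*padding=*/{0, 0, 0},
//       /*ceil_mode=*/false,
//       /*count_include_pad=*/true,
//       /*divisor_override=*/std::nullopt);
//   // out would have shape (2, 3, 4, 8, 8)
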
namespace at::meta {
using namespace ::at::native;

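// Meta function for avg_pool3d: validates kernel_size/stride/padding, checks
// that the input is a 4D (C, T, H, W) or 5D (N, C, T, H, W) tensor, computes
// the output spatial shape, and allocates the output tensor.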
TORCH_META_FUNC(avg_pool3d) (
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override
) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "avg_pool3d: kernel_size must be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "avg_pool3d: padding must be a single int, or a tuple of three ints");
  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
    "non-empty 4D or 5D (batch mode) tensor expected for input");

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
    "divisor must not be zero");

  /* sizes */
  const int64_t nbatch = input.size(0);
  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

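  // Output extent per dimension follows the usual pooling arithmetic
  // (a sketch; the exact rounding and clamping live in pooling_output_shape):
  //   out = floor_or_ceil((in + 2 * pad - kernel) / stride) + 1
  // with ceil_mode rounding up and the last window required to start inside
  // the (padded) input.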
  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    padT, padH, padW,
    1, 1, 1,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "avg_pool3d()",
    /*check_input_size=*/ true);

  /* resize output */
  if (input.ndimension() == 4) {
    set_output_raw_strided(0, {nslices, otime, oheight, owidth}, {}, input.options());
  }
  else {
    set_output_raw_strided(0, {nbatch, nslices, otime, oheight, owidth}, {}, input.options());
  }
}

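// Meta function for avg_pool3d_backward: re-derives the forward output shape
// from the pooling parameters, checks gradOutput against it, and allocates
// gradInput with the same shape and options as input.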
TORCH_META_FUNC(avg_pool3d_backward) (
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override
) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "avg_pool3d: kernel_size must be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "avg_pool3d: padding must be a single int, or a tuple of three ints");
  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
    "non-empty 4D or 5D (batch mode) tensor expected for input");

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
    "divisor must not be zero");

  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  /* XXX shape check behavior from TH */
  const int64_t otime_for_shape_check = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
  const int64_t oheight_for_shape_check = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
  const int64_t owidth_for_shape_check = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);

  avg_pool3d_backward_shape_check(
    input,
    gradOutput_,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    padT, padH, padW,
    itime, iheight, iwidth,
    otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check,
    "avg_pool3d_backward()");

  /* resize output */
  set_output_raw_strided(0, input.sizes(), {}, input.options());
}

} // namespace at::meta

namespace at::native {

namespace {

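// Forward kernel for a single (non-batched) stack of channel slices.
// Parallelizes over channels with at::parallel_for; for each output element
// it averages the overlapping input window, dividing by either
// divisor_override, the padded window size (count_include_pad), or the
// number of input elements actually covered.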
template <typename scalar_t>
static void avg_pool3d_out_frame(
    const scalar_t *input_p,
    scalar_t *output_p,
    int64_t nslices,
    int64_t itime,
    int64_t iwidth,
    int64_t iheight,
    int64_t otime,
    int64_t owidth,
    int64_t oheight,
    int kT,
    int kW,
    int kH,
    int dT,
    int dW,
    int dH,
    int padT,
    int padW,
    int padH,
    bool count_include_pad,
    std::optional<int64_t> divisor_override)
{
  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t i, j, ti;

      /* local pointers. */
      const scalar_t *ip = input_p + k * itime * iwidth * iheight;
      scalar_t *op = output_p + k * otime * owidth * oheight;
      for (i = 0; i < otime * oheight * owidth; ++i)
        *(op + i) = 0;

      /* loop over output */
      for (ti = 0; ti < otime; ti++)
      {
        for (i = 0; i < oheight; i++)
        {
          for (j = 0; j < owidth; j++)
          {
            /* compute pool range. */
            int64_t tstart = ti * dT - padT;
            int64_t hstart = i * dH - padH;
            int64_t wstart = j * dW - padW;
            int64_t tend = std::min(tstart + kT, itime + padT);
            int64_t hend = std::min(hstart + kH, iheight + padH);
            int64_t wend = std::min(wstart + kW, iwidth + padW);
            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
            tstart = std::max(tstart, (int64_t) 0);
            hstart = std::max(hstart, (int64_t) 0);
            wstart = std::max(wstart, (int64_t) 0);
            tend = std::min(tend, itime);
            hend = std::min(hend, iheight);
            wend = std::min(wend, iwidth);

            if (tstart >= tend || hstart >= hend || wstart >= wend) {
              ++op;
              continue;
            }

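            // Choose the divisor: an explicit override wins; otherwise
            // count_include_pad divides by the window size clipped only to
            // the padded bounds, while the default divides by the number of
            // real input elements under the window.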
            int64_t divide_factor = 0;
            if (divisor_override.has_value()) {
              divide_factor = divisor_override.value();
            } else {
              if (count_include_pad) {
                divide_factor = pool_size;
              } else {
                divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
              }
            }

            /* compute local sum: */
            scalar_t sum = 0.0;
            // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
            int64_t x, y, z;

            for (z = tstart; z < tend; z++)
            {
              for (y = hstart; y < hend; y++)
              {
                for (x = wstart; x < wend; x++)
                {
                  sum += *(ip + z * iwidth * iheight + y * iwidth + x);
                }
              }
            }

            /* set output to local average */
            *op++ += sum / divide_factor;
          }
        }
      }
    }
  });
}

} // anonymous namespace

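// CPU forward entry point. Makes the input contiguous, dispatches on the
// floating-point dtypes (plus Long), and runs avg_pool3d_out_frame either
// once (4D input) or per batch element in parallel (5D input).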
TORCH_IMPL_FUNC(avg_pool3d_out_cpu) (
  const Tensor& input_,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& output
) {
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  const int64_t nslices = input_.size(-4);
  const int64_t itime = input_.size(-3);
  const int64_t iheight = input_.size(-2);
  const int64_t iwidth = input_.size(-1);

  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);

  /* get contiguous input */
  Tensor input = input_.contiguous();

  if (input.ndimension() == 4) /* non-batch mode */
  {
    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_out_frame",
      [&] {
        const scalar_t *input_data = input.const_data_ptr<scalar_t>();
        scalar_t *output_data = output.data_ptr<scalar_t>();

        avg_pool3d_out_frame(
          input_data, output_data, nslices,
          itime, iwidth, iheight,
          otime, owidth, oheight,
          kT, kW, kH,
          dT, dW, dH,
          padT, padW, padH,
          count_include_pad,
          divisor_override);
      });
  }
  else /* batch mode */
  {
    const int64_t nbatch = input.size(0);
    const int64_t istride = nslices * itime * iwidth * iheight;
    const int64_t ostride = nslices * otime * owidth * oheight;

    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_out_frame",
      [&] {
        const scalar_t *input_data = input.const_data_ptr<scalar_t>();
        scalar_t *output_data = output.data_ptr<scalar_t>();

        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
          for (const auto p : c10::irange(start, end)) {
            avg_pool3d_out_frame(
              input_data + p * istride, output_data + p * ostride, nslices,
              itime, iwidth, iheight,
              otime, owidth, oheight,
              kT, kW, kH,
              dT, dW, dH,
              padT, padW, padH,
              count_include_pad,
              divisor_override
            );
          }
        });
      });
  }
}

namespace {

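// Backward kernel for a single (non-batched) stack of channel slices.
// Mirrors the forward pass: for each output gradient element it recomputes
// the window and divisor, then scatters val / divide_factor back to every
// input position the window covered.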
template <typename scalar_t>
static void avg_pool3d_backward_out_frame(
    scalar_t *gradInput_p,
    const scalar_t *gradOutput_p,
    int64_t nslices,
    int64_t itime,
    int64_t iwidth,
    int64_t iheight,
    int64_t otime,
    int64_t owidth,
    int64_t oheight,
    int kT,
    int kW,
    int kH,
    int dT,
    int dW,
    int dH,
    int padT,
    int padW,
    int padH,
    bool count_include_pad,
    std::optional<int64_t> divisor_override)
{
  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      /* local pointers */
      scalar_t *ip = gradInput_p + k * itime * iwidth * iheight;
      const scalar_t *op = gradOutput_p + k * otime * owidth * oheight;
      for (int64_t i = 0; i < itime * iwidth * iheight; i++)
        *(ip + i) = 0;

      /* loop over output */
      for (int64_t ti = 0; ti < otime; ti++)
      {
        for (int64_t i = 0; i < oheight; i++)
        {
          for (int64_t j = 0; j < owidth; j++)
          {
            int64_t tstart = ti * dT - padT;
            int64_t hstart = i * dH - padH;
            int64_t wstart = j * dW - padW;
            int64_t tend = std::min(tstart + kT, itime + padT);
            int64_t hend = std::min(hstart + kH, iheight + padH);
            int64_t wend = std::min(wstart + kW, iwidth + padW);
            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
            tstart = std::max(tstart, (int64_t) 0);
            hstart = std::max(hstart, (int64_t) 0);
            wstart = std::max(wstart, (int64_t) 0);
            tend = std::min(tend, itime);
            hend = std::min(hend, iheight);
            wend = std::min(wend, iwidth);

            int64_t divide_factor = 0;
            if (divisor_override.has_value()) {
              divide_factor = divisor_override.value();
            } else {
              if (count_include_pad) {
                divide_factor = pool_size;
              } else {
                divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
              }
            }

            /* scatter gradients out to footprint: */
            scalar_t val = *op++;

            for (auto z = tstart; z < tend; z++)
            {
              for (auto y = hstart; y < hend; y++)
              {
                for (auto x = wstart; x < wend; x++)
                {
                  *(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor;
                }
              }
            }
          }
        }
      }
    }
  });
}

} // anonymous namespace

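// CPU backward entry point. Makes gradOutput contiguous, zeroes gradInput,
// and runs avg_pool3d_backward_out_frame either once (4D input) or per
// batch element in parallel (5D input).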
TORCH_IMPL_FUNC(avg_pool3d_backward_out_cpu) (
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& gradInput
) {
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  /* get contiguous gradOutput */
  Tensor gradOutput = gradOutput_.contiguous();

  const int64_t otime = gradOutput.size(-3);
  const int64_t oheight = gradOutput.size(-2);
  const int64_t owidth = gradOutput.size(-1);

  gradInput.zero_();

  /* backprop */
  if (input.ndimension() == 4) /* non-batch mode */
  {
    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_backward_out_frame",
      [&] {
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
        const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

        avg_pool3d_backward_out_frame(
          gradInput_data, gradOutput_data,
          nslices,
          itime, iwidth, iheight,
          otime, owidth, oheight,
          kT, kW, kH,
          dT, dW, dH,
          padT, padW, padH,
          count_include_pad,
          divisor_override);
      });
  }
  else /* batch mode */
  {
    const int64_t nbatch = input.size(0);
    const int64_t istride = nslices * itime * iwidth * iheight;
    const int64_t ostride = nslices * otime * owidth * oheight;

    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_backward_out_frame",
      [&] {
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
        const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
          for (const auto p : c10::irange(start, end)) {
            avg_pool3d_backward_out_frame(
              gradInput_data + p * istride, gradOutput_data + p * ostride, nslices,
              itime, iwidth, iheight,
              otime, owidth, oheight,
              kT, kW, kH,
              dT, dW, dH,
              padT, padW, padH,
              count_include_pad,
              divisor_override
            );
          }
        });
      });
  }
}

} // namespace at::native