#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/Parallel.h>
#include <ATen/native/Pool.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/avg_pool3d_backward_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#endif

namespace at::meta {
using namespace ::at::native;

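// The meta functions below validate the arguments and compute the output
// shape. Per spatial dim the output extent follows the usual pooling formula
// (a sketch; the exact edge handling lives in pooling_output_shape):
//   out = floor((in + 2 * pad - kernel) / stride) + 1   (ceil with ceil_mode)
// and under ceil_mode the last window must start inside the input or the
// left padding, otherwise it is dropped.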
TORCH_META_FUNC(avg_pool3d) (
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override
) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "avg_pool3d: kernel_size must be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "avg_pool3d: padding must be a single int, or a tuple of three ints");
  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
    "non-empty 4D or 5D (batch mode) tensor expected for input");

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
    "divisor must not be zero");

  /* sizes */
  const int64_t nbatch = input.size(0);
  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    padT, padH, padW,
    1, 1, 1,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "avg_pool3d()",
    /*check_input_size=*/ true);

  /* resize output */
  if (input.ndimension() == 4) {
    set_output_raw_strided(0, {nslices, otime, oheight, owidth}, {}, input.options());
  }
  else {
    set_output_raw_strided(0, {nbatch, nslices, otime, oheight, owidth}, {}, input.options());
  }
}

TORCH_META_FUNC(avg_pool3d_backward) (
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override
) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "avg_pool3d: kernel_size must be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "avg_pool3d: padding must be a single int, or a tuple of three ints");
  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
    "non-empty 4D or 5D (batch mode) tensor expected for input");

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must not be zero");

  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  /* XXX shape check behavior from TH */
  const int64_t otime_for_shape_check = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
  const int64_t oheight_for_shape_check = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
  const int64_t owidth_for_shape_check = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);

  avg_pool3d_backward_shape_check(
    input,
    gradOutput_,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    padT, padH, padW,
    itime, iheight, iwidth,
    otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check,
    "avg_pool3d_backward()");

  /* resize output */
  set_output_raw_strided(0, input.sizes(), {}, input.options());
}

} // namespace at::meta

namespace at::native {

namespace {

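// One "frame" is all channels of a single sample. Each output cell holds the
// average of a kT x kH x kW input window; channels (nslices) are processed
// in parallel via at::parallel_for.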
template <typename scalar_t>
static void avg_pool3d_out_frame(
          const scalar_t *input_p,
          scalar_t *output_p,
          int64_t nslices,
          int64_t itime,
          int64_t iwidth,
          int64_t iheight,
          int64_t otime,
          int64_t owidth,
          int64_t oheight,
          int kT,
          int kW,
          int kH,
          int dT,
          int dW,
          int dH,
          int padT,
          int padW,
          int padH,
          bool count_include_pad,
          std::optional<int64_t> divisor_override)
{
  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t i, j, ti;

      /* local pointers. */
      const scalar_t *ip = input_p + k * itime * iwidth * iheight;
      scalar_t *op = output_p + k * otime * owidth * oheight;
      for (i = 0; i < otime * oheight * owidth; ++i)
        *(op + i) = 0;

      /* loop over output */
      for (ti = 0; ti < otime; ti++)
      {
        for (i = 0; i < oheight; i++)
        {
          for (j = 0; j < owidth; j++)
          {
            /* compute pool range. */
            int64_t tstart = ti * dT - padT;
            int64_t hstart = i  * dH - padH;
            int64_t wstart = j  * dW - padW;
            int64_t tend = std::min(tstart + kT, itime + padT);
            int64_t hend = std::min(hstart + kH, iheight + padH);
            int64_t wend = std::min(wstart + kW, iwidth + padW);
            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
            tstart = std::max(tstart, (int64_t) 0);
            hstart = std::max(hstart, (int64_t) 0);
            wstart = std::max(wstart, (int64_t) 0);
            tend = std::min(tend, itime);
            hend = std::min(hend, iheight);
            wend = std::min(wend, iwidth);

            if (tstart >= tend || hstart >= hend || wstart >= wend) {
              ++op;
              continue;
            }

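            // Choose the divisor: an explicit divisor_override wins;
            // otherwise count_include_pad divides by the full (padded)
            // window size, while the default divides by the number of
            // input cells the window actually covers.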
            int64_t divide_factor = 0;
            if (divisor_override.has_value()) {
              divide_factor = divisor_override.value();
            } else {
              if (count_include_pad) {
                divide_factor = pool_size;
              } else {
                divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
              }
            }

            /* compute local sum: */
            scalar_t sum = 0.0;
            // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
            int64_t x, y, z;

            for (z = tstart; z < tend; z++)
            {
              for (y = hstart; y < hend; y++)
              {
                for (x = wstart; x < wend; x++)
                {
                  sum += *(ip + z * iwidth * iheight + y * iwidth + x);
                }
              }
            }

            /* set output to local average */
            *op++ += sum / divide_factor;
          }
        }
      }
    }
  });
}

} // anonymous namespace

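// A minimal usage sketch of the op this kernel backs (shapes illustrative;
// stride defaults to kernel_size and padding to 0):
//   at::Tensor x = at::randn({2, 3, 8, 8, 8});
//   at::Tensor y = at::avg_pool3d(x, /*kernel_size=*/{2, 2, 2});
//   // y.sizes() == {2, 3, 4, 4, 4}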
TORCH_IMPL_FUNC(avg_pool3d_out_cpu) (
  const Tensor& input_,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& output
) {
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  const int64_t nslices = input_.size(-4);
  const int64_t itime = input_.size(-3);
  const int64_t iheight = input_.size(-2);
  const int64_t iwidth = input_.size(-1);

  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);

  /* get contiguous input */
  Tensor input = input_.contiguous();

  if (input.ndimension() == 4) /* non-batch mode */
  {
    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_out_frame",
      [&] {
        const scalar_t *input_data = input.const_data_ptr<scalar_t>();
        scalar_t *output_data = output.data_ptr<scalar_t>();

        avg_pool3d_out_frame(
          input_data, output_data, nslices,
          itime, iwidth, iheight,
          otime, owidth, oheight,
          kT, kW, kH,
          dT, dW, dH,
          padT, padW, padH,
          count_include_pad,
          divisor_override);
    });
  }
  else  /* batch mode */
  {
    const int64_t nbatch = input.size(0);
    const int64_t istride = nslices * itime * iwidth * iheight;
    const int64_t ostride = nslices * otime * owidth * oheight;
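    // istride/ostride span one sample in the contiguous NCDHW buffers; the
    // loop below hands each sample to avg_pool3d_out_frame.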

    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_out_frame",
      [&] {
        const scalar_t *input_data = input.const_data_ptr<scalar_t>();
        scalar_t *output_data = output.data_ptr<scalar_t>();

        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
          for (const auto p : c10::irange(start, end)) {
            avg_pool3d_out_frame(
              input_data + p * istride, output_data + p * ostride, nslices,
              itime, iwidth, iheight,
              otime, owidth, oheight,
              kT, kW, kH,
              dT, dW, dH,
              padT, padW, padH,
              count_include_pad,
              divisor_override
            );
          }
        });
    });
  }
}

namespace {

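// Backward of the kernel above: each output gradient is scattered uniformly
// back over its pooling window, adding val / divide_factor to every input
// cell it covered; overlapping windows accumulate via +=.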
template <typename scalar_t>
static void avg_pool3d_backward_out_frame(
          scalar_t *gradInput_p,
          const scalar_t *gradOutput_p,
          int64_t nslices,
          int64_t itime,
          int64_t iwidth,
          int64_t iheight,
          int64_t otime,
          int64_t owidth,
          int64_t oheight,
          int kT,
          int kW,
          int kH,
          int dT,
          int dW,
          int dH,
          int padT,
          int padW,
          int padH,
          bool count_include_pad,
          std::optional<int64_t> divisor_override)
{
  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      /* local pointers */
      scalar_t *ip = gradInput_p + k * itime * iwidth * iheight;
      const scalar_t *op = gradOutput_p + k * otime * owidth * oheight;
      for (int64_t i = 0; i < itime*iwidth*iheight; i++)
        *(ip + i) = 0;

      /* loop over output */
      for (int64_t ti = 0; ti < otime; ti++)
      {
        for (int64_t i = 0; i < oheight; i++)
        {
          for (int64_t j = 0; j < owidth; j++)
          {
            int64_t tstart = ti * dT - padT;
            int64_t hstart = i  * dH - padH;
            int64_t wstart = j  * dW - padW;
            int64_t tend = std::min(tstart + kT, itime + padT);
            int64_t hend = std::min(hstart + kH, iheight + padH);
            int64_t wend = std::min(wstart + kW, iwidth + padW);
            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
            tstart = std::max(tstart, (int64_t) 0);
            hstart = std::max(hstart, (int64_t) 0);
            wstart = std::max(wstart, (int64_t) 0);
            tend = std::min(tend, itime);
            hend = std::min(hend, iheight);
            wend = std::min(wend, iwidth);

            int64_t divide_factor = 0;
            if (divisor_override.has_value()) {
              divide_factor = divisor_override.value();
            } else {
              if (count_include_pad) {
                divide_factor = pool_size;
              } else {
                divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
              }
            }

            /* scatter gradients out to footprint: */
            scalar_t val = *op++;

            for (auto z = tstart; z < tend; z++)
            {
              for (auto y = hstart; y < hend; y++)
              {
                for (auto x = wstart; x < wend; x++)
                {
                  *(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor;
                }
              }
            }
          }
        }
      }
    }
  });
}

} // anonymous namespace

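// Normally reached through autograd rather than called directly; gradInput
// has the same shape as input (set by the meta function above).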
TORCH_IMPL_FUNC(avg_pool3d_backward_out_cpu) (
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& gradInput
) {
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  const int padT = safe_downcast<int, int64_t>(padding[0]);
  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);

  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  /* get contiguous gradOutput */
  Tensor gradOutput = gradOutput_.contiguous();

  const int64_t otime = gradOutput.size(-3);
  const int64_t oheight = gradOutput.size(-2);
  const int64_t owidth = gradOutput.size(-1);

  gradInput.zero_();

  /* backprop */
  if (input.ndimension() == 4) /* non-batch mode */
  {
    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_backward_out_frame",
      [&] {
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
        const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

        avg_pool3d_backward_out_frame(
          gradInput_data, gradOutput_data,
          nslices,
          itime, iwidth, iheight,
          otime, owidth, oheight,
          kT, kW, kH,
          dT, dW, dH,
          padT, padW, padH,
          count_include_pad,
          divisor_override);
    });
  }
  else /* batch mode */
  {
    const int64_t nbatch = input.size(0);
    const int64_t istride = nslices * itime * iwidth * iheight;
    const int64_t ostride = nslices * otime * owidth * oheight;

    AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(),
      "avg_pool3d_backward_out_frame",
      [&] {
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
        const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
          for (const auto p : c10::irange(start, end)) {
            avg_pool3d_backward_out_frame(
              gradInput_data + p * istride, gradOutput_data + p * ostride, nslices,
              itime, iwidth, iheight,
              otime, owidth, oheight,
              kT, kW, kH,
              dT, dW, dH,
              padT, padW, padH,
              count_include_pad,
              divisor_override
            );
          }
        });
    });
  }
}

} // namespace at::native