#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#include <ATen/native/AdaptivePooling.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_adaptive_avg_pool3d.h>
#include <ATen/ops/_adaptive_avg_pool3d_backward_native.h>
#include <ATen/ops/_adaptive_avg_pool3d_native.h>
#include <ATen/ops/adaptive_avg_pool3d_backward_native.h>
#include <ATen/ops/adaptive_avg_pool3d_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

namespace {

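// Forward kernel for one (D, T, H, W) frame. Each output cell (ot, oh, ow)
// averages the input window [start_index, end_index) along every spatial axis;
// start_index/end_index (from AdaptivePooling.h) evaluate to
// floor(o * isize / osize) and ceil((o + 1) * isize / osize), so the windows
// cover the input exactly. Work is parallelized over the plane dimension D.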
template <typename scalar_t>
static void adaptive_avg_pool3d_out_frame(
    const scalar_t* input_p,
    scalar_t* output_p,
    int64_t sizeD,
    int64_t isizeT,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeT,
    int64_t osizeH,
    int64_t osizeW,
    int64_t istrideD,
    int64_t istrideT,
    int64_t istrideH,
    int64_t istrideW) {
  at::parallel_for(0, sizeD, 1, [&](int64_t start, int64_t end) {
    for (const auto d : c10::irange(start, end)) {
      /* loop over output */
      for (const auto ot : c10::irange(osizeT)) {
        auto istartT = start_index(ot, osizeT, isizeT);
        auto iendT = end_index(ot, osizeT, isizeT);
        auto kT = iendT - istartT;

        for (const auto oh : c10::irange(osizeH)) {
          auto istartH = start_index(oh, osizeH, isizeH);
          auto iendH = end_index(oh, osizeH, isizeH);
          auto kH = iendH - istartH;

          for (const auto ow : c10::irange(osizeW)) {
            auto istartW = start_index(ow, osizeW, isizeW);
            auto iendW = end_index(ow, osizeW, isizeW);
            auto kW = iendW - istartW;

            /* local pointers */
            const scalar_t* ip = input_p + d * istrideD + istartT * istrideT +
                istartH * istrideH + istartW * istrideW;
            scalar_t* op = output_p + d * osizeT * osizeH * osizeW +
                ot * osizeH * osizeW + oh * osizeW + ow;

            /* compute local average: */
            scalar_t sum = 0;
            for (const auto it : c10::irange(kT)) {
              for (const auto ih : c10::irange(kH)) {
                for (const auto iw : c10::irange(kW)) {
                  scalar_t val =
                      *(ip + it * istrideT + ih * istrideH + iw * istrideW);
                  sum += val;
                }
              }
            }

            /* set output to local average */
            *op = sum / kT / kH / kW;
          }
        }
      }
    }
  });
}

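// Checks shapes and dtypes, resizes the output, and dispatches the frame
// kernel. A 4D input (D, T, H, W) is a single frame; a 5D input
// (N, D, T, H, W) is handled in the batched branch below.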
void adaptive_avg_pool3d_out_cpu_template(
    Tensor& output,
    Tensor const& input,
    IntArrayRef output_size) {
  TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");

  for (const auto i : c10::irange(1, input.ndimension())) {
    TORCH_CHECK(
        input.size(i) > 0,
        "adaptive_avg_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
        "but input has sizes ",
        input.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }

  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got ",
      input.sizes());
  TORCH_CHECK(input.dtype() == output.dtype(),
      "expected dtype ", input.dtype(), " for `output` but got dtype ", output.dtype());

  /* sizes */
  int64_t sizeD = input.size(-4);
  int64_t isizeT = input.size(-3);
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  /* strides */
  int64_t istrideD = input.stride(-4);
  int64_t istrideT = input.stride(-3);
  int64_t istrideH = input.stride(-2);
  int64_t istrideW = input.stride(-1);
  /* output sizes */
  auto osizeT = output_size[0];
  auto osizeH = output_size[1];
  auto osizeW = output_size[2];

  if (input.ndimension() == 4) {
    output.resize_({sizeD, osizeT, osizeH, osizeW});

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
          auto input_data = input.const_data_ptr<scalar_t>();
          auto output_data = output.data_ptr<scalar_t>();
          adaptive_avg_pool3d_out_frame<scalar_t>(
              input_data,
              output_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW,
              istrideD,
              istrideT,
              istrideH,
              istrideW);
        });
  } else {
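    // 5D input (N, D, T, H, W): parallelize over the batch dimension and run
    // the frame kernel once per sample. The input batch offset uses
    // input.stride(0) (the input may be non-contiguous); the freshly resized
    // output is contiguous, so its offset is b * sizeD * osizeT * osizeH * osizeW.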
    output.resize_({input.size(-5), sizeD, osizeT, osizeH, osizeW});
    int64_t n = input.size(0);

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
          auto input_data = input.const_data_ptr<scalar_t>();
          auto output_data = output.data_ptr<scalar_t>();
          at::parallel_for(0, n, 1, [&](int64_t start, int64_t end) {
            for (const auto b : c10::irange(start, end)) {
              adaptive_avg_pool3d_out_frame<scalar_t>(
                  input_data + b * input.stride(0),
                  output_data + b * sizeD * osizeT * osizeH * osizeW,
                  sizeD,
                  isizeT,
                  isizeH,
                  isizeW,
                  osizeT,
                  osizeH,
                  osizeW,
                  istrideD,
                  istrideT,
                  istrideH,
                  istrideW);
            }
          });
    });
  }
}

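// Backward kernel for one contiguous (D, T, H, W) frame: the gradient of each
// output cell is divided by its pooling-window volume (kT * kH * kW) and
// accumulated into every input element the window covered. Unlike the forward
// frame kernel, this one takes no strides and assumes contiguous buffers.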
template <typename scalar_t>
static void adaptive_avg_pool3d_backward_out_frame(
    scalar_t* gradInput_p,
    const scalar_t* gradOutput_p,
    int64_t sizeD,
    int64_t isizeT,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeT,
    int64_t osizeH,
    int64_t osizeW) {
  at::parallel_for(0, sizeD, 1, [&](int64_t start, int64_t end) {
    for (const auto d : c10::irange(start, end)) {
      scalar_t* gradInput_p_d = gradInput_p + d * isizeT * isizeW * isizeH;
      const scalar_t* gradOutput_p_d = gradOutput_p + d * osizeT * osizeW * osizeH;

      /* calculate average */
      for (const auto ot : c10::irange(osizeT)) {
        auto istartT = start_index(ot, osizeT, isizeT);
        auto iendT = end_index(ot, osizeT, isizeT);
        auto kT = iendT - istartT;

        for (const auto oh : c10::irange(osizeH)) {
          auto istartH = start_index(oh, osizeH, isizeH);
          auto iendH = end_index(oh, osizeH, isizeH);
          auto kH = iendH - istartH;

          for (const auto ow : c10::irange(osizeW)) {
            auto istartW = start_index(ow, osizeW, isizeW);
            auto iendW = end_index(ow, osizeW, isizeW);
            auto kW = iendW - istartW;

            scalar_t grad_delta =
                gradOutput_p_d[ot * osizeH * osizeW + oh * osizeW + ow] / kT /
                kH / kW;

            for (const auto it : c10::irange(istartT, iendT)) {
              for (const auto ih : c10::irange(istartH, iendH)) {
                for (const auto iw : c10::irange(istartW, iendW)) {
                  /* update gradient */
                  gradInput_p_d[it * isizeH * isizeW + ih * isizeW + iw] +=
                      grad_delta;
                }
              }
            }
          }
        }
      }
    }
  });
}

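// Makes gradOutput contiguous, checks it is non-empty, and dispatches the
// backward frame kernel for 4D and 5D inputs. gradInput is expected to be
// zero-initialized by the caller (both public entry points below do this),
// since the kernel only accumulates into it.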
Tensor& adaptive_avg_pool3d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  /* get contiguous gradOutput */
  auto gradOutput = gradOutput_.contiguous();

  adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool3d_backward");

  /* sizes */
  int64_t sizeD = input.size(-4);
  int64_t isizeT = input.size(-3);
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeT = gradOutput.size(-3);
  int64_t osizeH = gradOutput.size(-2);
  int64_t osizeW = gradOutput.size(-1);

  /* backprop */
  if (input.ndimension() == 4) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
          /* get raw pointers */
          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

          adaptive_avg_pool3d_backward_out_frame<scalar_t>(
              gradInput_data,
              gradOutput_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW);
        });
  } else {
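    // 5D input: gradInput and the contiguous gradOutput are laid out
    // per-sample, so each batch element b is processed independently in
    // parallel with plain sizeD * T * H * W offsets.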
    int64_t n = input.size(0);

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
          /* get raw pointers */
          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
          at::parallel_for(0, n, 1, [&](int64_t start, int64_t end) {
            for (const auto b : c10::irange(start, end)) {
              adaptive_avg_pool3d_backward_out_frame<scalar_t>(
                  gradInput_data + b * sizeD * isizeT * isizeH * isizeW,
                  gradOutput_data + b * sizeD * osizeT * osizeH * osizeW,
                  sizeD,
                  isizeT,
                  isizeH,
                  isizeW,
                  osizeT,
                  osizeH,
                  osizeW);
            }
          });
    });
  }
  return gradInput;
}

} // namespace

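// Public CPU entry points. The *_out variants write into a caller-provided
// tensor; the functional variants allocate their own output/gradInput.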
Tensor& adaptive_avg_pool3d_out_cpu(const Tensor& input,
    IntArrayRef output_size,
    Tensor& output) {
  adaptive_avg_pool3d_out_cpu_template(output, input, output_size);
  return output;
}

Tensor adaptive_avg_pool3d_cpu(Tensor const& input, IntArrayRef output_size) {
  auto output = at::empty({0}, input.options());
  adaptive_avg_pool3d_out_cpu_template(output, input, output_size);
  return output;
}

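// Shape-level wrapper that works with SymInt sizes. When output_size is
// {1, 1, 1} the result is simply the mean over the last three dimensions,
// which is much cheaper than the generic kernel; for channels-last-3d inputs
// the result is re-strided so it keeps reporting a channels-last layout.
// Every other output size falls through to _adaptive_avg_pool3d.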
Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_size) {
  TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");
  TORCH_CHECK(
        (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0),
        "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ",
        "but received {", output_size[0], ", ", output_size[1], ", ", output_size[2], "}");

  if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1) {
    // in this case, adaptive pooling is just computing the mean over the last
    // three (spatial) dimensions, which can be done more efficiently
    Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true);
    if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d) {
      // assert ndim == 5, since ndim = 4 doesn't give channels_last
      const auto n = input.sym_size(0);
      const auto c = input.sym_size(1);
      out.as_strided__symint({n, c, 1, 1, 1}, {c, 1, c, c, c});
    }
    return out;
  } else {
    return _adaptive_avg_pool3d_symint(input, output_size);
  }
}

Tensor& adaptive_avg_pool3d_backward_out_cpu(const Tensor& gradOutput_,
    const Tensor& input,
    Tensor& gradInput) {
  gradInput.resize_as_(input).zero_();
  adaptive_avg_pool3d_backward_out_cpu_template(gradInput, gradOutput_, input);
  return gradInput;
}

Tensor adaptive_avg_pool3d_backward_cpu(const Tensor& gradOutput_,
    const Tensor& input) {
  auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  adaptive_avg_pool3d_backward_out_cpu_template(gradInput, gradOutput_, input);
  return gradInput;
}

} // namespace at::native