xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/FractionalMaxPool3d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorMeta.h>
6 #include <ATen/native/FractionalMaxPooling.h>
7 
8 #include <c10/util/irange.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/empty.h>
15 #include <ATen/ops/fractional_max_pool3d_backward_native.h>
16 #include <ATen/ops/fractional_max_pool3d_native.h>
17 #endif
18 
19 
20 namespace at::meta {
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)21 TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
22   const at::Tensor& input_,
23   IntArrayRef pool_size,
24   IntArrayRef output_size,
25   const at::Tensor& randomSamples
26 ) {
27   TORCH_CHECK(
28       pool_size.size() == 3,
29       "fractional_max_pool3d: kernel_size must either be a single Int or tuple of three Ints")
30   TORCH_CHECK(
31       output_size.size() == 3,
32       "fractional_max_pool3d: output_size must either be a single Int or tuple of three Ints")
33   int64_t outputT = output_size[0];
34   int64_t outputH = output_size[1];
35   int64_t outputW = output_size[2];
36   int64_t poolSizeT = pool_size[0];
37   int64_t poolSizeH = pool_size[1];
38   int64_t poolSizeW = pool_size[2];
39 
40   int64_t numBatch = 1;
41   int64_t planeDim = 0;
42   int64_t timeDim = 1;
43   int64_t heightDim = 2;
44   int64_t widthDim = 3;
45 
46   int64_t ndims = input_.ndimension();
47   TORCH_CHECK(ndims == 4 || ndims == 5,
48               "fractional_max_pool3d_out(): Expected 4D or 5D tensor, but got: ",
49               input_.sizes());
50   for (const auto i : c10::irange(1, ndims)) {
51     TORCH_CHECK(input_.size(i) > 0,
52                 "fractional_max_pool3d_out(): Expected input to have non-zero size for non-batch dimensions, but got",
53                 input_.sizes(), " with dimension ", i, " being empty.");
54   }
55 
56   if (ndims == 5) {
57     numBatch = input_.size(0);
58     planeDim++;
59     timeDim++;
60     heightDim++;
61     widthDim++;
62   }
63 
64   /* sizes */
65   int64_t numPlanes = input_.size(planeDim);
66   int64_t inputT = input_.size(timeDim);
67   int64_t inputH = input_.size(heightDim);
68   int64_t inputW = input_.size(widthDim);
69 
70   TORCH_CHECK(outputT + poolSizeT - 1 < inputT,
71            "fractional_max_pool3d_out(): pool time ", poolSizeT,
72            " too large relative to input time ", inputT);
73   TORCH_CHECK(outputW + poolSizeW - 1 < inputW,
74            "fractional_max_pool3d_out(): pool width ", poolSizeW,
75            " too large relative to input width ", inputW);
76   TORCH_CHECK(outputH + poolSizeH - 1 < inputH,
77            "fractional_max_pool3d_out(): pool height ", poolSizeH,
78            " too large relative to input height ", inputH);
79 
80   if (ndims == 4) {
81     /* resize output */
82     set_output_raw_strided(0, {numPlanes, outputT, outputH, outputW}, {}, input_.options());
83     /* indices will contain the locations for each output point */
84     set_output_raw_strided(1, {numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong));
85   } else {
86     set_output_raw_strided(0, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options());
87     /* indices will contain the locations for each output point */
88     set_output_raw_strided(1, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong));
89   }
90 
91   return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
92                                                          .set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW)
93                                                          .set_outputT(outputT).set_outputH(outputH).set_outputW(outputW);
94 }
95 
96 } // namespace at::meta
97 
98 namespace at::native {
99 namespace {
100 
101 template<typename scalar_t>
fractional_max_pool3d_out_single_batch_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW,int64_t poolSizeT,int64_t poolSizeH,int64_t poolSizeW)102 static void fractional_max_pool3d_out_single_batch_frame(
103   const scalar_t* input,
104   scalar_t* output,
105   int64_t* indices,
106   const scalar_t* randomSamples,
107   int64_t numPlanes,
108   int64_t inputT, int64_t inputH, int64_t inputW,
109   int64_t outputT, int64_t outputH, int64_t outputW,
110   int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
111 
112   at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
113     for (const auto plane : c10::irange(start, end)) {
114       /* each plane contains 3 random samples,
115          one for T, one for W, and one for H */
116       const scalar_t* randomSamplesForPlane = randomSamples + plane * 3;
117 
118       /* Generate interval sequence */
119       auto sequenceT = generate_intervals<scalar_t>(
120           randomSamplesForPlane[0], inputT, outputT, poolSizeT);
121       auto sequenceH = generate_intervals<scalar_t>(
122           randomSamplesForPlane[1], inputH, outputH, poolSizeH);
123       auto sequenceW = generate_intervals<scalar_t>(
124           randomSamplesForPlane[2], inputW, outputW, poolSizeW);
125 
126       /* loop over output */
127       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
128       int64_t t, h, w;
129 
130       const scalar_t* inputForPlane = input + plane * inputT * inputH * inputW;
131       scalar_t* outputForPlane = output + plane * outputT * outputH * outputW;
132       int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW;
133 
134       for (t = 0; t < outputT; ++t) {
135         int64_t inputTStart = sequenceT[t];
136 
137         for (h = 0; h < outputH; ++h) {
138           int64_t inputHStart = sequenceH[h];
139 
140           for (w = 0; w < outputW; ++w) {
141             int64_t inputWStart = sequenceW[w];
142 
143             int64_t t2 = inputTStart, h2 = inputHStart, w2 = inputWStart;
144             scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
145             int64_t maxIndex = t2 * inputH * inputW + h2 * inputW + w2;
146 
147             for (t2 = inputTStart; t2 < inputTStart + poolSizeT; ++t2) {
148               for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
149                 for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
150                   AT_ASSERT(t2 >= 0 && t2 < inputT);
151                   AT_ASSERT(h2 >= 0 && h2 < inputH);
152                   AT_ASSERT(w2 >= 0 && w2 < inputW);
153 
154                   int64_t planeIndex = t2 * inputH * inputW + h2 * inputW + w2;
155                   scalar_t val = inputForPlane[planeIndex];
156                   if (val > maxVal || std::isnan(val)) {
157                     maxVal = val;
158                     maxIndex = planeIndex;
159                   }
160                 }
161               }
162             }
163 
164             outputForPlane[t * outputH * outputW + h * outputW + w] = maxVal;
165             indicesForPlane[t * outputH * outputW + h * outputW + w] = maxIndex;
166           }
167         }
168       }
169     }
170   });
171 }
172 
173 template<typename scalar_t>
fractional_max_pool3d_out_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int64_t numBatch,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW,int64_t poolSizeT,int64_t poolSizeH,int64_t poolSizeW)174 static void fractional_max_pool3d_out_frame(
175   const scalar_t* input,
176   scalar_t* output,
177   int64_t* indices,
178   const scalar_t* randomSamples,
179   int64_t numBatch, int64_t numPlanes,
180   int64_t inputT, int64_t inputH, int64_t inputW,
181   int64_t outputT, int64_t outputH, int64_t outputW,
182   int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
183     if(numBatch == 1) {
184       fractional_max_pool3d_out_single_batch_frame<scalar_t>(
185         input, output, indices, randomSamples,
186         numPlanes,
187         inputT, inputH, inputW,
188         outputT, outputH, outputW,
189         poolSizeT, poolSizeH, poolSizeW
190       );
191       return;
192     }
193 
194     at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
195       for (const auto batch : c10::irange(start, end)) {
196         fractional_max_pool3d_out_single_batch_frame<scalar_t>(
197           input + batch * numPlanes * inputW * inputH * inputT,
198           output + batch * numPlanes * outputW * outputH * outputT,
199           indices + batch * numPlanes * outputW * outputH * outputT,
200           randomSamples + batch * numPlanes * 3,
201           numPlanes,
202           inputT, inputH, inputW,
203           outputT, outputH, outputW,
204           poolSizeT, poolSizeH, poolSizeW
205         );
206       }
207     });
208   }
209 
210 } // anonymous namespace
211 
// CPU forward implementation of the structured fractional_max_pool3d op.
// The size arguments are the values the meta function stored in its
// precompute struct; `output` and `indices` were allocated there.
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
  const at::Tensor& input_,
  int64_t poolSizeT,
  int64_t poolSizeH,
  int64_t poolSizeW,
  int64_t outputT,
  int64_t outputH,
  int64_t outputW,
  const at::Tensor& randomSamples_,
  int64_t numBatch,
  int64_t numPlanes,
  int64_t inputT,
  int64_t inputH,
  int64_t inputW,
  const at::Tensor& output,
  const at::Tensor& indices) {
  // Validates randomSamples_ against input_ (check implemented outside this file).
  fractional_max_pool_check_shape</*ndim*/ 3>(input_, randomSamples_);

  const bool emptyOutput = output.numel() == 0;
  if (emptyOutput) {
    return;
  }

  // The frame kernels walk raw pointers, so they need dense storage.
  const Tensor inputDense = input_.contiguous();
  const Tensor samplesDense = randomSamples_.contiguous();

  AT_DISPATCH_FLOATING_TYPES_AND2(
    kBFloat16,
    kHalf,
    inputDense.scalar_type(),
    "fractional_max_pool3d_out_frame",
    [&] {
      fractional_max_pool3d_out_frame<scalar_t>(
        inputDense.const_data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        indices.data_ptr<int64_t>(),
        samplesDense.const_data_ptr<scalar_t>(),
        numBatch, numPlanes,
        inputT, inputH, inputW,
        outputT, outputH, outputW,
        poolSizeT, poolSizeH, poolSizeW);
    });
}
258 
259 namespace {
260 
261 template<typename scalar_t>
fractional_max_pool3d_backward_out_single_batch_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW)262 static void fractional_max_pool3d_backward_out_single_batch_frame(
263   scalar_t* gradInput,
264   const scalar_t* gradOutput,
265   const int64_t* indices,
266   int64_t numPlanes,
267   int64_t inputT, int64_t inputH, int64_t inputW,
268   int64_t outputT, int64_t outputH, int64_t outputW) {
269 
270   at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
271     for (const auto plane : c10::irange(start, end)) {
272       scalar_t* gradInputForPlane = gradInput + plane * inputT * inputH * inputW;
273       const scalar_t* gradOutputForPlane = gradOutput +
274                   plane * outputT * outputH * outputW;
275       const int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW;
276 
277       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
278       int64_t h, w, t;
279       for (t = 0; t < outputT; ++t) {
280         for (h = 0; h < outputH; ++h) {
281           for (w = 0; w < outputW; ++w) {
282             int64_t outputIndex = t * outputH * outputW + h * outputW + w;
283             int64_t index = indicesForPlane[outputIndex];
284             AT_ASSERT(index >= 0 && index < inputT * inputH * inputW);
285             gradInputForPlane[index] += gradOutputForPlane[outputIndex];
286           }
287         }
288       }
289     }
290   });
291 }
292 
293 template<typename scalar_t>
fractional_max_pool3d_backward_out_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int64_t numBatch,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW)294 static void fractional_max_pool3d_backward_out_frame(
295   scalar_t* gradInput,
296   const scalar_t* gradOutput,
297   const int64_t* indices,
298   int64_t numBatch, int64_t numPlanes,
299   int64_t inputT, int64_t inputH, int64_t inputW,
300   int64_t outputT, int64_t outputH, int64_t outputW) {
301     if(numBatch == 1) {
302       fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
303         gradInput, gradOutput, indices,
304         numPlanes,
305         inputT, inputH, inputW,
306         outputT, outputH, outputW
307       );
308       return;
309     }
310 
311     at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
312       for (const auto batch : c10::irange(start, end)) {
313         fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
314           gradInput + batch * numPlanes * inputW * inputH * inputT,
315           gradOutput + batch * numPlanes * outputW * outputH * outputT,
316           indices + batch * numPlanes * outputW * outputH * outputT,
317           numPlanes,
318           inputT, inputH, inputW,
319           outputT, outputH, outputW
320         );
321       }
322     });
323   }
324 
325 
fractional_max_pool3d_backward_out_cpu_template(const Tensor & input,const Tensor & gradOutput_,Tensor & gradInput,IntArrayRef output_size,IntArrayRef pool_size,const Tensor & indices)326 void fractional_max_pool3d_backward_out_cpu_template(
327   const Tensor& input,
328   const Tensor& gradOutput_,
329   Tensor& gradInput,
330   IntArrayRef output_size,
331   IntArrayRef pool_size /* unused */,
332   const Tensor& indices) {
333 
334   int64_t outputT = output_size[0];
335   int64_t outputH = output_size[1];
336   int64_t outputW = output_size[2];
337 
338   int64_t numBatch = 1;
339   int64_t planeDim = 0;
340   int64_t timeDim = 1;
341   int64_t heightDim = 2;
342   int64_t widthDim = 3;
343 
344   int64_t ndims = input.ndimension();
345   if (ndims == 5) {
346     numBatch = input.size(0);
347     planeDim = 1;
348     heightDim++;
349     widthDim++;
350     timeDim++;
351   }
352 
353   /* sizes */
354   int64_t numPlanes = input.size(planeDim);
355   int64_t inputT = input.size(timeDim);
356   int64_t inputH = input.size(heightDim);
357   int64_t inputW = input.size(widthDim);
358 
359   TORCH_CHECK(outputT == gradOutput_.size(timeDim),
360            "fractional_max_pool3d_backward_out(): gradOutput time unexpected");
361   TORCH_CHECK(outputH == gradOutput_.size(heightDim),
362            "fractional_max_pool3d_backward_out(): ",
363            "gradOutput height unexpected");
364   TORCH_CHECK(outputW == gradOutput_.size(widthDim),
365            "fractional_max_pool3d_backward_out(): gradOutput width unexpected");
366 
367   /* get contiguous gradOutput */
368   auto gradOutput = gradOutput_.contiguous();
369 
370   /* resize */
371   gradInput.resize_as_(input);
372   gradInput.zero_();
373 
374   /* backprop */
375   AT_DISPATCH_FLOATING_TYPES_AND2(
376     kBFloat16,
377     kHalf,
378     input.scalar_type(),
379     "fractional_max_pool3d_backward_out_frame",
380     [&]{
381       fractional_max_pool3d_backward_out_frame<scalar_t>(
382         gradInput.data_ptr<scalar_t>(),
383         gradOutput.const_data_ptr<scalar_t>(),
384         indices.const_data_ptr<int64_t>(),
385         numBatch, numPlanes,
386         inputT, inputH, inputW,
387         outputT, outputH, outputW
388       );
389     }
390   );
391 }
392 
393 }// anonymous namespace
394 
fractional_max_pool3d_backward_out_cpu(const at::Tensor & gradOutput_,const at::Tensor & input,IntArrayRef pool_size,IntArrayRef output_size,const at::Tensor & indices,at::Tensor & gradInput)395 Tensor& fractional_max_pool3d_backward_out_cpu(const at::Tensor& gradOutput_,
396   const at::Tensor& input,
397   IntArrayRef pool_size,
398   IntArrayRef output_size,
399   const at::Tensor& indices,
400   at::Tensor& gradInput) {
401   fractional_max_pool3d_backward_out_cpu_template(
402     input,
403     gradOutput_,
404     gradInput,
405     output_size,
406     pool_size,
407     indices);
408   return gradInput;
409 }
410 
fractional_max_pool3d_backward_cpu(const at::Tensor & gradOutput_,const at::Tensor & input,IntArrayRef pool_size,IntArrayRef output_size,const at::Tensor & indices)411 Tensor fractional_max_pool3d_backward_cpu(
412   const at::Tensor& gradOutput_,
413   const at::Tensor& input,
414   IntArrayRef pool_size,
415   IntArrayRef output_size,
416   const at::Tensor& indices) {
417   Tensor gradInput = at::empty({0}, input.options());
418   fractional_max_pool3d_backward_out_cpu_template(
419     input,
420     gradOutput_,
421     gradInput,
422     output_size,
423     pool_size,
424     indices);
425   return gradInput;
426 }
427 
428 } // namespace at::native
429