xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/FractionalMaxPool2d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorMeta.h>
6 #include <ATen/native/FractionalMaxPooling.h>
7 #include <c10/util/irange.h>
8 
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/Functions.h>
11 #include <ATen/NativeFunctions.h>
12 #else
13 #include <ATen/ops/fractional_max_pool2d_backward_native.h>
14 #include <ATen/ops/fractional_max_pool2d_native.h>
15 #endif
16 
17 namespace at {
18 
19 namespace meta {
TORCH_META_FUNC(fractional_max_pool2d)20 TORCH_META_FUNC(fractional_max_pool2d) (
21   const at::Tensor& input,
22   IntArrayRef pool_size,
23   IntArrayRef output_size,
24   const at::Tensor& randomSamples
25 ) {
26   TORCH_CHECK(
27       pool_size.size() == 2,
28       "fractional_max_pool2d: kernel_size must either be a single Int or tuple of Ints")
29   TORCH_CHECK(
30       output_size.size() == 2,
31       "fractional_max_pool2d: output_size must either be a single Int or tuple of Ints")
32   int64_t numBatch = 1;
33   int64_t planeDim = 0;
34   int64_t heightDim = 1;
35   int64_t widthDim = 2;
36   int64_t outputH = output_size[0];
37   int64_t outputW = output_size[1];
38   int64_t poolSizeH = pool_size[0];
39   int64_t poolSizeW = pool_size[1];
40 
41   int64_t ndims = input.ndimension();
42   TORCH_CHECK(ndims == 3 || ndims == 4,
43               "fractional_max_pool2d(): Expected 3D or 4D tensor, but got: ", input.sizes());
44   for (const auto i : c10::irange(1, ndims)) {
45     TORCH_CHECK(input.size(i) > 0,
46                 "fractional_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, but got",
47                 input.sizes(), " with dimension ", i, " being empty.");
48   }
49 
50 
51   if (ndims == 4) {
52     numBatch = input.size(0);
53     planeDim++;
54     heightDim++;
55     widthDim++;
56   }
57 
58   /* sizes */
59   int64_t numPlanes = input.size(planeDim);
60   int64_t inputH = input.size(heightDim);
61   auto inputW = input.size(widthDim);
62 
63   TORCH_CHECK(outputH + poolSizeH - 1 <= inputH,
64     "fractional_max_pool2d(): pool height ", poolSizeH,
65     " too large relative to input height ", inputH);
66   TORCH_CHECK(outputW + poolSizeW - 1 <= inputW,
67     "fractional_max_pool2d(): pool width ", poolSizeW,
68     " too large relative to input width ", inputW);
69 
70   if (ndims == 3) {
71     set_output_raw_strided(0, {numPlanes, outputH, outputW}, {}, input.options());
72     /* indices will contain the locations for each output point */
73     set_output_raw_strided(1, {numPlanes, outputH, outputW}, {}, input.options().dtype(kLong));
74   } else {
75     set_output_raw_strided(0, {numBatch, numPlanes, outputH, outputW}, {}, input.options());
76     /* indices will contain the locations for each output point */
77     set_output_raw_strided(1, {numBatch, numPlanes, outputH, outputW}, {}, input.options().dtype(kLong));
78   }
79 }
80 
// Meta function for fractional_max_pool2d_backward: validates the arguments
// and declares gradInput (output 0) with the same sizes and options as
// `input`.
//
// gradOutput_ - gradient w.r.t. the pooled output; its height/width must
//               match output_size.
// input       - the forward input, 3D (planes, H, W) or 4D (batch, planes, H, W).
// pool_size   - unused.
// output_size - {outputH, outputW}; must hold exactly two values.
// indices     - not inspected by this function.
TORCH_META_FUNC(fractional_max_pool2d_backward)(
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size /* unused */,
  IntArrayRef output_size,
  const at::Tensor& indices) {

  // output_size is indexed below; guard against an out-of-bounds read.
  TORCH_CHECK(
      output_size.size() == 2,
      "fractional_max_pool2d_backward(): output_size must contain exactly two values, but got ",
      output_size.size());
  const auto outputH = output_size[0];
  const auto outputW = output_size[1];

  const auto ndims = input.ndimension();
  TORCH_CHECK(ndims == 3 || ndims == 4,
    "fractional_max_pool2d_backward(): Expected 3D or 4D input tensor, but got: ",
    input.sizes());

  // Height and width are the last two dimensions for both supported ranks.
  const int64_t heightDim = ndims - 2;
  const int64_t widthDim = ndims - 1;

  // Only sizes are needed here, so gradOutput_ is queried directly instead of
  // materialising a contiguous copy.
  TORCH_CHECK(outputW == gradOutput_.size(widthDim),
    "fractional_max_pool2d_backward(): gradOutput width unexpected");
  TORCH_CHECK(outputH == gradOutput_.size(heightDim),
    "fractional_max_pool2d_backward(): gradOutput height unexpected");

  /* resize: gradInput has exactly the shape of input */
  set_output_raw_strided(0, input.sizes(), {}, input.options());
}
124 } // namespace meta
125 
126 namespace native {
127 namespace {
128 
129 template <typename scalar_t>
fractional_max_pool2d_out_single_batch_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int numPlanes,int inputW,int inputH,int outputW,int outputH,int poolSizeW,int poolSizeH)130 static void fractional_max_pool2d_out_single_batch_frame(
131   const scalar_t* input,
132   scalar_t* output,
133   int64_t* indices,
134   const scalar_t* randomSamples,
135   int numPlanes,
136   int inputW, int inputH,
137   int outputW, int outputH,
138   int poolSizeW, int poolSizeH) {
139   at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
140     for (const auto plane : c10::irange(start, end)) {
141       /* each plane contains 2 random samples, one for W and one for H */
142       const scalar_t* randomSamplesForPlane = randomSamples + plane * 2;
143 
144       /* Generate interval sequence */
145       auto sequenceW = generate_intervals<scalar_t>(
146           randomSamplesForPlane[0], inputW, outputW, poolSizeW);
147       auto sequenceH = generate_intervals<scalar_t>(
148           randomSamplesForPlane[1], inputH, outputH, poolSizeH);
149 
150       /* loop over output */
151       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
152       int h, w;
153 
154       const scalar_t* inputForPlane = input + plane * inputW * inputH;
155       scalar_t* outputForPlane = output + plane * outputW * outputH;
156       int64_t* indicesForPlane = indices + plane * outputW * outputH;
157 
158       for (h = 0; h < outputH; ++h) {
159         int inputHStart = sequenceH[h];
160 
161         for (w = 0; w < outputW; ++w) {
162           int inputWStart = sequenceW[w];
163 
164           int h2 = inputHStart, w2 = inputWStart;
165           scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
166           int64_t maxIndex = h2 * inputW + w2;
167 
168           for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
169             for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
170               AT_ASSERT(h2 >= 0 && h2 < inputH);
171               AT_ASSERT(w2 >= 0 && w2 < inputW);
172 
173               int planeIndex = h2 * inputW + w2;
174               scalar_t val = inputForPlane[planeIndex];
175               if (val > maxVal || std::isnan(val)) {
176                 maxVal = val;
177                 maxIndex = planeIndex;
178               }
179             }
180           }
181 
182           outputForPlane[h * outputW + w] = maxVal;
183           indicesForPlane[h * outputW + w] = maxIndex;
184         }
185       }
186     }
187   });
188 }
189 
190 template <typename scalar_t>
fractional_max_pool2d_out_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int numBatch,int numPlanes,int inputW,int inputH,int outputW,int outputH,int poolSizeW,int poolSizeH)191 static void fractional_max_pool2d_out_frame(
192   const scalar_t* input,
193   scalar_t* output,
194   int64_t* indices,
195   const scalar_t* randomSamples,
196   int numBatch, int numPlanes,
197   int inputW, int inputH,
198   int outputW, int outputH,
199   int poolSizeW, int poolSizeH) {
200     if(numBatch == 1) {
201       fractional_max_pool2d_out_single_batch_frame<scalar_t>(
202         input,
203         output,
204         indices,
205         randomSamples,
206         numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH
207       );
208       return;
209     }
210     at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
211       for (const auto batch : c10::irange(start, end)) {
212         fractional_max_pool2d_out_single_batch_frame<scalar_t>(
213           input + batch * numPlanes * inputH * inputW,
214           output + batch * numPlanes * outputH * outputW,
215           indices + batch * numPlanes * outputH * outputW,
216           randomSamples + batch * numPlanes * 2,
217           numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
218       }
219     });
220   }
221 
222 template <typename scalar_t>
fractional_max_pool2d_backward_out_single_batch_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int numPlanes,int inputW,int inputH,int outputW,int outputH)223 static void fractional_max_pool2d_backward_out_single_batch_frame(
224   scalar_t* gradInput,
225   const scalar_t* gradOutput,
226   const int64_t* indices,
227   int numPlanes,
228   int inputW, int inputH,
229   int outputW, int outputH) {
230   at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
231     for (const auto plane : c10::irange(start, end)) {
232       scalar_t* gradInputForPlane = gradInput + plane * inputW * inputH;
233       const scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
234       const int64_t* indicesForPlane = indices + plane * outputW * outputH;
235 
236       for (int h = 0; h < outputH; ++h) {
237         for (int w = 0; w < outputW; ++w) {
238           int outputIndex = h * outputW + w;
239           int64_t index = indicesForPlane[outputIndex];
240           AT_ASSERT(index >= 0 && index < static_cast<int64_t>(inputW) * inputH);
241 
242           gradInputForPlane[index] += gradOutputForPlane[outputIndex];
243         }
244       }
245     }
246   });
247 }
248 
249 template <typename scalar_t>
fractional_max_pool2d_backward_out_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int numBatch,int numPlanes,int inputW,int inputH,int outputW,int outputH)250 static void fractional_max_pool2d_backward_out_frame(
251   scalar_t* gradInput,
252   const scalar_t* gradOutput,
253   const int64_t* indices,
254   int numBatch, int numPlanes,
255   int inputW, int inputH,
256   int outputW, int outputH) {
257     if(numBatch == 1) {
258       fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
259         gradInput, gradOutput, indices,
260         numPlanes,
261         inputW, inputH, outputW, outputH
262       );
263       return;
264     }
265     at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
266       for (const auto batch : c10::irange(start, end)) {
267         fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
268           gradInput + batch * numPlanes * inputH * inputW,
269           gradOutput + batch * numPlanes * outputH * outputW,
270           indices + batch * numPlanes * outputH * outputW,
271           numPlanes, inputW, inputH, outputW, outputH);
272       }
273     });
274 }
275 
276 } // anonymous namespace
277 
// CPU implementation of fractional_max_pool2d.
//
// input_         - 3D (planes, H, W) or 4D (batch, planes, H, W) tensor.
// pool_size      - {poolSizeH, poolSizeW}.
// output_size    - {outputH, outputW}.
// randomSamples_ - random samples consumed by the kernel; shape is checked by
//                  fractional_max_pool_check_shape.
// output/indices - pre-allocated results, written through their raw data
//                  pointers.
TORCH_IMPL_FUNC(fractional_max_pool2d_out_cpu) (
  const at::Tensor& input_,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples_,
  const at::Tensor& output,
  const at::Tensor& indices) {

  fractional_max_pool_check_shape</*ndim*/ 2>(input_, randomSamples_);

  // Nothing to compute for an empty result.
  if (output.numel() == 0) {
    return;
  }

  const int64_t outputH = output_size[0]; // output.size(heightDim)
  const int64_t outputW = output_size[1]; // output.size(widthDim)
  const int64_t poolSizeH = pool_size[0];
  const int64_t poolSizeW = pool_size[1];

  /* get contiguous input and samples */
  const auto input = input_.contiguous();
  const auto randomSamples = randomSamples_.contiguous();

  // With a leading batch dimension every other dimension index shifts by one.
  const bool batched = (input.ndimension() == 4);
  const int64_t planeDim = batched ? 1 : 0;
  const int64_t heightDim = planeDim + 1;
  const int64_t widthDim = planeDim + 2;

  /* sizes */
  const int64_t numBatch = batched ? input.size(0) : 1;
  const int64_t numPlanes = input.size(planeDim);
  const int64_t inputH = input.size(heightDim);
  const int64_t inputW = input.size(widthDim);

  // fractional_max_pool2d_out_frame takes its extents as `int`. Refuse values
  // that would be silently truncated by that conversion instead of running
  // the kernel with corrupted sizes.
  const auto fitsInInt = [](int64_t value) {
    return static_cast<int64_t>(static_cast<int>(value)) == value;
  };
  TORCH_CHECK(
      fitsInInt(numBatch) && fitsInInt(numPlanes) &&
      fitsInInt(inputH) && fitsInInt(inputW) &&
      fitsInInt(outputH) && fitsInInt(outputW) &&
      fitsInInt(poolSizeH) && fitsInInt(poolSizeW),
      "fractional_max_pool2d(): input, output and pool sizes must each fit in a 32-bit int, but got input ",
      input.sizes(), ", output_size ", output_size, ", pool_size ", pool_size);

  AT_DISPATCH_FLOATING_TYPES_AND2(
    kBFloat16,
    kHalf,
    input.scalar_type(),
    "fractional_max_pool2d_out_frame", [&] {
      auto input_data = input.const_data_ptr<scalar_t>();
      auto output_data = output.data_ptr<scalar_t>();
      auto indices_data = indices.data_ptr<int64_t>();
      auto randomSamples_data = randomSamples.const_data_ptr<scalar_t>();
      fractional_max_pool2d_out_frame<scalar_t>(
        input_data,
        output_data,
        indices_data,
        randomSamples_data,
        numBatch, numPlanes,
        inputW, inputH,
        outputW, outputH,
        poolSizeW, poolSizeH);
    }
  );
}
340 
// CPU implementation of fractional_max_pool2d_backward.
//
// gradOutput_ - gradient w.r.t. the pooled output.
// input       - the forward input, 3D (planes, H, W) or 4D (batch, planes, H, W);
//               only its sizes and dtype are used.
// pool_size   - unused.
// output_size - {outputH, outputW}.
// indices     - int64 flat in-plane offsets of the selected input elements.
// gradInput   - pre-allocated result; zeroed here, then accumulated into.
TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cpu) (
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices,
  const at::Tensor& gradInput) {

  // The kernel only adds into gradInput, so it must start from zero.
  gradInput.zero_();

  const int64_t outputH = output_size[0];
  const int64_t outputW = output_size[1];

  // With a leading batch dimension every other dimension index shifts by one.
  const bool batched = (input.ndimension() == 4);
  const int64_t planeDim = batched ? 1 : 0;
  const int64_t heightDim = planeDim + 1;
  const int64_t widthDim = planeDim + 2;

  /* sizes */
  const int64_t numBatch = batched ? input.size(0) : 1;
  const int64_t numPlanes = input.size(planeDim);
  const int64_t inputH = input.size(heightDim);
  const int64_t inputW = input.size(widthDim);

  // fractional_max_pool2d_backward_out_frame takes its extents as `int`.
  // Refuse values that would be silently truncated by that conversion.
  const auto fitsInInt = [](int64_t value) {
    return static_cast<int64_t>(static_cast<int>(value)) == value;
  };
  TORCH_CHECK(
      fitsInInt(numBatch) && fitsInInt(numPlanes) &&
      fitsInInt(inputH) && fitsInInt(inputW) &&
      fitsInInt(outputH) && fitsInInt(outputW),
      "fractional_max_pool2d_backward(): input and output sizes must each fit in a 32-bit int, but got input ",
      input.sizes(), ", output_size ", output_size);

  /* get contiguous gradOutput and indices: the kernel walks both through raw
     pointers with dense offsets */
  const auto gradOutput = gradOutput_.contiguous();
  const auto indicesContig = indices.contiguous();

  /* backprop */
  AT_DISPATCH_FLOATING_TYPES_AND2(
    kBFloat16,
    kHalf,
    input.scalar_type(), "fractional_max_pool2d_backward_out_frame", [&] {
      auto gradInput_data = gradInput.data_ptr<scalar_t>();
      auto gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
      auto indices_data = indicesContig.const_data_ptr<int64_t>();
      fractional_max_pool2d_backward_out_frame<scalar_t>(
        gradInput_data,
        gradOutput_data,
        indices_data,
        numBatch, numPlanes,
        inputW, inputH,
        outputW, outputH
      );
    }
  );
}
394 
395 } // at::native
396 } // at
397