xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/FractionalMaxPooling.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorUtils.h>
4 #include <c10/util/irange.h>
5 
6 namespace at::native {
7 
8 template<typename scalar_t>
generate_intervals(scalar_t sample,int64_t inputSize,int64_t outputSize,int64_t poolSize)9 inline std::vector<int64_t> generate_intervals(
10     scalar_t sample,
11     int64_t inputSize,
12     int64_t outputSize,
13     int64_t poolSize) {
14   std::vector<int64_t> sequence(outputSize);
15   if (outputSize > 1) {
16     scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
17       static_cast<scalar_t>(outputSize - 1);
18 
19     for (const auto i : c10::irange(outputSize - 1)) {
20       sequence[i] =
21         static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
22     }
23   }
24   if (outputSize > 0) {
25     sequence[outputSize - 1] = inputSize - poolSize;
26   }
27   return sequence;
28 }
29 
30 template <int64_t ndim>
fractional_max_pool_check_shape(const Tensor & input,const Tensor & randomSamples)31 inline void fractional_max_pool_check_shape(
32     const Tensor& input,
33     const Tensor& randomSamples) {
34 
35   TORCH_CHECK(
36       input.scalar_type() == randomSamples.scalar_type(),
37       "Expect _random_samples to have the same dtype as input");
38 
39   int64_t ndimension = randomSamples.ndimension();
40   TORCH_CHECK(
41       ndimension == 3,
42       "Expect _random_samples to have 3 dimensions, got ", ndimension);
43 
44   int64_t N = randomSamples.size(0);
45   int64_t C = randomSamples.size(1);
46   int64_t D = randomSamples.size(2);
47 
48   int64_t input_batch = 0, input_channel = 0;
49   if (ndim == 2) {
50     // fractional_max_pool2d
51     if (input.ndimension() == 3) {
52       input_batch = 1;
53       input_channel = input.size(0);
54     } else {
55       input_batch = input.size(0);
56       input_channel = input.size(1);
57     }
58   } else {
59     // factional_max_pool3d
60     if (input.ndimension() == 4) {
61       input_batch = 1;
62       input_channel = input.size(0);
63     } else {
64       input_batch = input.size(0);
65       input_channel = input.size(1);
66     }
67   }
68 
69   TORCH_CHECK(
70       N >= input_batch,
71       "Expect _random_samples.size(0) no less then input batch size.");
72   TORCH_CHECK(
73       C == input_channel,
74       "Expect _random_samples.size(1) equals to input channel size.");
75   TORCH_CHECK(
76       D == ndim,
77       "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
78 }
79 
80 } // namespace at::native
81