1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorUtils.h>
4 #include <c10/util/irange.h>
5
6 namespace at::native {
7
8 template<typename scalar_t>
generate_intervals(scalar_t sample,int64_t inputSize,int64_t outputSize,int64_t poolSize)9 inline std::vector<int64_t> generate_intervals(
10 scalar_t sample,
11 int64_t inputSize,
12 int64_t outputSize,
13 int64_t poolSize) {
14 std::vector<int64_t> sequence(outputSize);
15 if (outputSize > 1) {
16 scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
17 static_cast<scalar_t>(outputSize - 1);
18
19 for (const auto i : c10::irange(outputSize - 1)) {
20 sequence[i] =
21 static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
22 }
23 }
24 if (outputSize > 0) {
25 sequence[outputSize - 1] = inputSize - poolSize;
26 }
27 return sequence;
28 }
29
30 template <int64_t ndim>
fractional_max_pool_check_shape(const Tensor & input,const Tensor & randomSamples)31 inline void fractional_max_pool_check_shape(
32 const Tensor& input,
33 const Tensor& randomSamples) {
34
35 TORCH_CHECK(
36 input.scalar_type() == randomSamples.scalar_type(),
37 "Expect _random_samples to have the same dtype as input");
38
39 int64_t ndimension = randomSamples.ndimension();
40 TORCH_CHECK(
41 ndimension == 3,
42 "Expect _random_samples to have 3 dimensions, got ", ndimension);
43
44 int64_t N = randomSamples.size(0);
45 int64_t C = randomSamples.size(1);
46 int64_t D = randomSamples.size(2);
47
48 int64_t input_batch = 0, input_channel = 0;
49 if (ndim == 2) {
50 // fractional_max_pool2d
51 if (input.ndimension() == 3) {
52 input_batch = 1;
53 input_channel = input.size(0);
54 } else {
55 input_batch = input.size(0);
56 input_channel = input.size(1);
57 }
58 } else {
59 // factional_max_pool3d
60 if (input.ndimension() == 4) {
61 input_batch = 1;
62 input_channel = input.size(0);
63 } else {
64 input_batch = input.size(0);
65 input_channel = input.size(1);
66 }
67 }
68
69 TORCH_CHECK(
70 N >= input_batch,
71 "Expect _random_samples.size(0) no less then input batch size.");
72 TORCH_CHECK(
73 C == input_channel,
74 "Expect _random_samples.size(1) equals to input channel size.");
75 TORCH_CHECK(
76 D == ndim,
77 "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
78 }
79
80 } // namespace at::native
81