1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/Pool.h>
7
8 namespace at::native {
9
check_max_pool1d(const Tensor & self,IntArrayRef kernel_size,IntArrayRef stride,IntArrayRef padding,IntArrayRef dilation,bool ceil_mode)10 inline void check_max_pool1d(
11 const Tensor& self,
12 IntArrayRef kernel_size,
13 IntArrayRef stride,
14 IntArrayRef padding,
15 IntArrayRef dilation,
16 bool ceil_mode) {
17
18 TORCH_CHECK(
19 self.dim() == 2 || self.dim() == 3,
20 "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
21 TORCH_CHECK(
22 kernel_size.size() == 1,
23 "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
24 kernel_size.size());
25 TORCH_CHECK(
26 stride.empty() || stride.size() == 1,
27 "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
28 stride.size());
29 TORCH_CHECK(
30 padding.size() == 1,
31 "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
32 padding.size());
33 TORCH_CHECK(
34 dilation.size() == 1,
35 "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
36 dilation.size());
37
38 // If stride=None then set it to kernel_size
39 if (stride.empty()) {
40 stride = kernel_size;
41 }
42
43 TORCH_CHECK(
44 kernel_size[0] > 0,
45 "max_pool1d() kernel_size must be greater than zero, but got ",
46 kernel_size[0]);
47 TORCH_CHECK(
48 stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
49 TORCH_CHECK(
50 padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
51 TORCH_CHECK(
52 padding[0] <= kernel_size[0] / 2,
53 "max_pool1d() padding should be at most half of kernel size, but got padding=",
54 padding[0],
55 " and kernel_size=",
56 kernel_size[0]);
57 TORCH_CHECK(
58 dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
59
60 const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
61 TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
62 }
63
64 // TODO(Heitor) Template by dimension
65 struct PoolingParams1D {
66 int64_t NB; // Number of batches
67 int64_t NC; // Number of channels
68 int64_t IW; // Input width
69 int64_t OW; // Output width
70 int64_t KW; // Kernel width
71 int64_t SJ; // Column stride
72 int64_t PJ; // Column padding
73 int64_t DJ; // Column dilation
74
75 // Return index of input element for the given kernel and output index
indexPoolingParams1D76 inline int64_t index(int64_t kj, int64_t oj) const {
77 return oj * SJ + kj * DJ - PJ;
78 }
79
80 // Return index of first output within bounds for this kernel index
valid_output_startPoolingParams1D81 inline int64_t valid_output_start(int64_t kj) const {
82 int64_t ij = index(kj, 0);;
83 return ij < 0 ? at::divup(-ij, SJ) : 0;
84 }
85
86 // Return index one past last output within bounds for this kernel index
valid_output_endPoolingParams1D87 inline int64_t valid_output_end(int64_t kj) const {
88 int64_t ij = index(kj, OW - 1);
89 return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
90 }
91 };
92
// Function-pointer type for a max_pool1d kernel. It takes a mutable Tensor, a
// const Tensor and the pooling geometry. NOTE(review): the mutable Tensor is
// presumably the output and the const Tensor the input; confirm against the
// registered kernels.
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);

// Declares (but does not define) the dispatch stub through which max_pool1d
// kernels of type `pooling_fn` are reached. The stub's definition and any
// kernel registrations are not in this header; presumably they live in a .cpp
// file (verify).
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
96
97 } // namespace at::native
98