1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/core/Tensor.h>
5 #include <ATen/cpu/vec/vec.h>
6 #include <ATen/native/MaxPooling.h>
7 #include <c10/util/irange.h>
8
9 namespace at::native {
10
11 namespace {
12
13 template <typename scalar_t>
max_pool1d_kernel(scalar_t * C10_RESTRICT op,const scalar_t * C10_RESTRICT ip,const PoolingParams1D & p)14 inline void max_pool1d_kernel(
15 scalar_t* C10_RESTRICT op,
16 const scalar_t* C10_RESTRICT ip,
17 const PoolingParams1D& p) {
18 for (const auto kj : c10::irange(p.KW)) {
19 int64_t oj = p.valid_output_start(kj);
20 int64_t oe = p.valid_output_end(kj);
21 int64_t ij = p.index(kj, oj);
22 for (; oj < oe; ++oj, ij += p.SJ) {
23 scalar_t val = ip[ij];
24 bool update_max = std::isnan(val) || op[oj] < val;
25 op[oj] = update_max ? val : op[oj];
26 }
27 }
28 }
29
max_pool1d_impl(Tensor & output,const Tensor & input,const PoolingParams1D & p)30 void max_pool1d_impl(
31 Tensor& output,
32 const Tensor& input,
33 const PoolingParams1D& p) {
34 AT_DISPATCH_FLOATING_TYPES_AND2(
35 ScalarType::BFloat16,
36 ScalarType::Half,
37 input.scalar_type(),
38 "max_pool1d_impl",
39 [&] {
40 const Tensor in = input.contiguous();
41 scalar_t* const OP = output.data_ptr<scalar_t>();
42 const scalar_t* const IP = in.const_data_ptr<scalar_t>();
43
44 // Value used for padding
45 scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
46 ? -std::numeric_limits<scalar_t>::infinity()
47 : std::numeric_limits<scalar_t>::lowest();
48
49 at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
50 for (const auto it : c10::irange(begin, end)) {
51 scalar_t* op = OP + it * p.OW;
52 const scalar_t* ip = IP + it * p.IW;
53 std::fill_n(op, p.OW, FILL);
54 max_pool1d_kernel(op, ip, p);
55 }
56 });
57 });
58 }
59
60 } // namespace
61
62 REGISTER_DISPATCH(max_pool1d_stub, &max_pool1d_impl);
63
64 } // namespace at::native
65