xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/MaxPooling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/core/Tensor.h>
5 #include <ATen/cpu/vec/vec.h>
6 #include <ATen/native/MaxPooling.h>
7 #include <c10/util/irange.h>
8 
9 namespace at::native {
10 
11 namespace {
12 
13 template <typename scalar_t>
max_pool1d_kernel(scalar_t * C10_RESTRICT op,const scalar_t * C10_RESTRICT ip,const PoolingParams1D & p)14 inline void max_pool1d_kernel(
15     scalar_t* C10_RESTRICT op,
16     const scalar_t* C10_RESTRICT ip,
17     const PoolingParams1D& p) {
18   for (const auto kj : c10::irange(p.KW)) {
19     int64_t oj = p.valid_output_start(kj);
20     int64_t oe = p.valid_output_end(kj);
21     int64_t ij = p.index(kj, oj);
22     for (; oj < oe; ++oj, ij += p.SJ) {
23       scalar_t val = ip[ij];
24       bool update_max = std::isnan(val) || op[oj] < val;
25       op[oj] = update_max ? val : op[oj];
26     }
27   }
28 }
29 
max_pool1d_impl(Tensor & output,const Tensor & input,const PoolingParams1D & p)30 void max_pool1d_impl(
31     Tensor& output,
32     const Tensor& input,
33     const PoolingParams1D& p) {
34   AT_DISPATCH_FLOATING_TYPES_AND2(
35       ScalarType::BFloat16,
36       ScalarType::Half,
37       input.scalar_type(),
38       "max_pool1d_impl",
39       [&] {
40         const Tensor in = input.contiguous();
41         scalar_t* const OP = output.data_ptr<scalar_t>();
42         const scalar_t* const IP = in.const_data_ptr<scalar_t>();
43 
44         // Value used for padding
45         scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
46             ? -std::numeric_limits<scalar_t>::infinity()
47             : std::numeric_limits<scalar_t>::lowest();
48 
49         at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
50           for (const auto it : c10::irange(begin, end)) {
51             scalar_t* op = OP + it * p.OW;
52             const scalar_t* ip = IP + it * p.IW;
53             std::fill_n(op, p.OW, FILL);
54             max_pool1d_kernel(op, ip, p);
55           }
56         });
57       });
58 }
59 
60 } // namespace
61 
// Registers max_pool1d_impl as the implementation behind max_pool1d_stub.
// NOTE(review): which CPU capability variant this registration applies to is
// decided by how this translation unit is built; confirm in the build setup.
62 REGISTER_DISPATCH(max_pool1d_stub, &max_pool1d_impl);
63 
64 } // namespace at::native
65