#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/MaxPooling.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/max_pool1d_native.h>
#include <ATen/ops/max_pool1d_with_indices.h>
#include <ATen/ops/quantized_max_pool1d.h>
#endif

namespace at::native {

DEFINE_DISPATCH(max_pool1d_stub);

namespace {

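// Index-free CPU fast path: computes the pooled values directly through
// max_pool1d_stub without materializing an indices tensor.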
Tensor max_pool1d_impl(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  NoNamesGuard guard;

  // If stride=None then set it to kernel_size
  if (stride.empty()) {
    stride = kernel_size;
  }

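  // Pooling shapes: NB = batch, NC = channels, IW = input width; KW/SJ/PJ/DJ
  // are the kernel width, stride, padding and dilation along the W dimension.
  // Unbatched (2D) input is treated as a batch of size 1.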
  const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
  const int64_t NC = self.size(-2);
  const int64_t IW = self.size(-1);
  const int64_t KW = kernel_size[0];
  const int64_t SJ = stride[0];
  const int64_t PJ = padding[0];
  const int64_t DJ = dilation[0];

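  // Output width follows the usual pooling formula (see pooling_output_shape):
  //   OW = floor((IW + 2*PJ - DJ*(KW - 1) - 1) / SJ) + 1,
  // with ceiling instead of floor (plus a clamp so the last window starts
  // inside the input) when ceil_mode is true.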
  const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
  Tensor output = at::empty({NB, NC, OW}, self.options());

  PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
  max_pool1d_stub(self.device().type(), output, self, params);

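  // For unbatched (2D) input, drop the batch dimension of size 1 again so the
  // output matches the input's dimensionality.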
  if (self.dim() == 2) {
    output.squeeze_(0);
  }

  guard.reset();
  namedinference::propagate_names(output, self);

  return output;
}

} // namespace

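// Public entry point for at::max_pool1d: validates the input, routes quantized
// tensors to the quantized kernel, and falls back to max_pool1d_with_indices
// whenever gradients, non-CPU devices, or tensor subclasses are involved;
// otherwise it uses the index-free CPU fast path above.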
Tensor max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {

  auto ndim = self.ndimension();
  TORCH_CHECK(
      (ndim == 2 && self.sym_size(0) != 0 && self.sym_size(1) != 0) ||
          (ndim == 3 && self.sym_size(1) != 0 && self.sym_size(2) != 0),
      "max_pool1d: Expected 2D or 3D (batch mode) tensor with optional 0 dim batch size for input, but got:",
      self.sym_sizes());

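  // Quantized inputs are handled by a dedicated quantized kernel.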
  if (self.is_quantized()) {
    return at::quantized_max_pool1d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }

  check_max_pool1d(self, kernel_size, stride, padding, dilation, ceil_mode);
  if ((self.requires_grad() && at::GradMode::is_enabled()) ||
      self._fw_grad(/*level */ 0).defined() ||
      !self.device().is_cpu() ||
      isTensorSubclassLike(self)) {
    // Needs indices for grad and with_indices defines CUDA dispatch
    return std::get<0>(at::max_pool1d_with_indices(
        self, kernel_size, stride, padding, dilation, ceil_mode));
  }
  return max_pool1d_impl(
      self, kernel_size, stride, padding, dilation, ceil_mode);
}
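
// Usage sketch (illustrative; the tensor shape and pooling parameters below
// are arbitrary examples chosen to show the output-size arithmetic):
//
//   at::Tensor x = at::randn({8, 4, 32});  // (N, C, W) float CPU tensor
//   at::Tensor y = at::max_pool1d(
//       x, /*kernel_size=*/{3}, /*stride=*/{2},
//       /*padding=*/{1}, /*dilation=*/{1}, /*ceil_mode=*/false);
//   // y.sizes() == {8, 4, 16}: OW = (32 + 2*1 - 1*(3 - 1) - 1) / 2 + 1 = 16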

} // namespace at::native