// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/DilatedMaxPool2d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/Pool.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
#include <ATen/ops/max_pool2d_with_indices_native.h>
#endif
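
// Hedged note on the guard above: AT_PER_OPERATOR_HEADERS selects the
// fine-grained per-operator headers over the monolithic Functions.h /
// NativeFunctions.h. Both branches declare the max_pool2d_with_indices
// entry points used below; the split form just keeps incremental ATen
// rebuilds smaller.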
namespace at::meta {
using namespace at::native;
TORCH_META_FUNC(max_pool2d_with_indices)
(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
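  // Illustrative note (hedged): with stride omitted, the step defaults to the
  // kernel size, so a {3, 3} kernel pools non-overlapping 3x3 blocks; this
  // mirrors the Python-side default of stride = kernel_size.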

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast) {
    TORCH_CHECK(input.ndimension() == 4,
      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
      "non-empty 3D or 4D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }

  /* sizes */
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
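  // For reference, a sketch of the standard formula (Pool.h is authoritative):
  //   out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
  // with ceil_mode rounding up instead, plus a correction so the last window
  // starts inside the input or left padding. E.g. in=32, k=3, pad=1, d=1,
  // stride=2: out = floor((32 + 2 - 2 - 1) / 2) + 1 = 15 + 1 = 16.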

  pool2d_shape_check(
    input,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight, outputWidth, memory_format);

  /* resize output and indices */
  DimnameList maybe_names = input.has_names() ? input.names() : DimnameList{};
  if (input.ndimension() == 3) {
    set_output_raw_strided(0, {nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), maybe_names);
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format).dtype(kLong), maybe_names);
  } else {
    set_output_raw_strided(0, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), maybe_names);
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format).dtype(kLong), maybe_names);
  }
}
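
// Hedged usage sketch (illustrative only, not part of this file): the meta
// function above determines that, for example,
//   auto x = at::randn({1, 3, 32, 32});
//   auto [out, idx] = at::max_pool2d_with_indices(
//       x, /*kernel_size=*/{3, 3}, /*stride=*/{2, 2}, /*padding=*/{1, 1});
// yields `out` and `idx` both of shape {1, 3, 16, 16}, with `idx` a kLong
// tensor holding, per output element, the flattened HxW offset of the max.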

TORCH_META_FUNC(max_pool2d_with_indices_backward)
(const Tensor& gradOutput,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());

  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast) {
    TORCH_CHECK(input.ndimension() == 4,
      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
      "non-empty 3D or 4D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }

  /* sizes */
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  /* XXX preserve the existing shape check behavior */
  const int64_t outputHeight_for_shape_check = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth_for_shape_check = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);

  max_pool2d_backward_shape_check(
    input,
    gradOutput,
    indices,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight_for_shape_check, outputWidth_for_shape_check,
    memory_format);

  set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(memory_format),
             input.has_names() ? input.names() : DimnameList{});
}
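
// Hedged sketch of the contract checked above (illustrative only):
//   auto [out, idx] = at::max_pool2d_with_indices(x, {3, 3}, {2, 2}, {1, 1});
//   auto grad_out = at::ones_like(out);
//   auto gx = at::max_pool2d_with_indices_backward(
//       grad_out, x, {3, 3}, {2, 2}, {1, 1}, {1, 1}, /*ceil_mode=*/false, idx);
// grad_out must match out's dtype and the recomputed output shape; gx comes
// back with x's shape and memory format.
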
} // namespace at::meta

namespace at::native {

TORCH_IMPL_FUNC(max_pool2d_with_indices_out_cpu)
(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
  NoNamesGuard guard;

  // Validation and output sizing already happened in the meta function; here
  // we just re-derive the scalar pooling params and hand off to the CPU stub.
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  max_pool2d_kernel(
      kCPU, output, indices, input,
      kW, kH,
      dW, dH,
      padW, padH,
      dilationW, dilationH);
}

TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_cpu)
(const Tensor& gradOutput,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices,
const Tensor& gradInput) {
  NoNamesGuard guard;

  // Structured kernels pass outputs as const Tensor&; the dispatch stub needs
  // a mutable reference, hence the const_cast when invoking the kernel.
  gradInput.zero_();
  max_pool2d_backward_kernel(
      kCPU, const_cast<Tensor&>(gradInput),
      gradOutput, indices);
}
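
// Semantics note (hedged): max pooling routes each output gradient to the
// input location recorded in `indices`. gradInput is zeroed first, and the
// kernel then accumulates gradOutput values at those argmax positions, so
// overlapping windows that share a max sum their contributions.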

DEFINE_DISPATCH(max_pool2d_kernel);
DEFINE_DISPATCH(max_pool2d_backward_kernel);
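
// DEFINE_DISPATCH instantiates the DispatchStub variables declared in Pool.h.
// The concrete CPU implementations are registered elsewhere (in PyTorch, the
// vectorized kernels under aten/src/ATen/native/cpu), along the lines of this
// hedged sketch:
//   REGISTER_DISPATCH(max_pool2d_kernel, &max_pool2d_kernel_impl);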

} // namespace at::native