#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/Pool.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
#include <ATen/ops/max_pool2d_with_indices_native.h>
#endif

namespace at::meta {
using namespace at::native;
TORCH_META_FUNC(max_pool2d_with_indices)
(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints")
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
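  // (an empty stride means "stride == kernel_size", the Python-level default,
  // which is exactly how it is substituted below)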
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints")
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast) {
    TORCH_CHECK(input.ndimension() == 4,
      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
      "non-empty 3D or 4D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }

  /* sizes */
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

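  // The output extent follows the usual pooling formula, roughly
  //   floor((inputSize + 2 * pad - dilation * (kernelSize - 1) - 1) / stride) + 1
  // (ceil instead of floor when ceil_mode is set); pooling_output_shape in
  // ATen/native/Pool.h handles the exact edge cases.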
  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);

  pool2d_shape_check(
    input,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight, outputWidth, memory_format);

  /* resize output and indices */
  DimnameList maybe_names = input.has_names() ? input.names() : DimnameList{};
  if (input.ndimension() == 3) {
    set_output_raw_strided(0, {nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), maybe_names);
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format).dtype(kLong), maybe_names);
  } else {
    set_output_raw_strided(0, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), maybe_names);
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format).dtype(kLong), maybe_names);
  }
}

TORCH_META_FUNC(max_pool2d_with_indices_backward)
(const Tensor& gradOutput,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints")
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints")
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());

  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast) {
    TORCH_CHECK(input.ndimension() == 4,
      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
      "non-empty 3D or 4D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }

  /* sizes */
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  /* XXX preserve the existing shape check behavior */
  const int64_t outputHeight_for_shape_check = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth_for_shape_check = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);

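  // The backward shape check verifies that gradOutput and indices have the
  // spatial extents the forward pass would have produced for these parameters.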
  max_pool2d_backward_shape_check(
    input,
    gradOutput,
    indices,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight_for_shape_check, outputWidth_for_shape_check,
    memory_format);

  set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(memory_format),
    input.has_names() ? input.names() : DimnameList{});
}
} // namespace at::meta

namespace at::native {

TORCH_IMPL_FUNC(max_pool2d_with_indices_out_cpu)
(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
  NoNamesGuard guard;

  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

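  // output and indices were already allocated by the meta function above;
  // the actual pooling is delegated to the kernel registered for this
  // dispatch stub.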
  max_pool2d_kernel(
      kCPU, output, indices, input,
      kW, kH,
      dW, dH,
      padW, padH,
      dilationW, dilationH);
}

TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_cpu)
(const Tensor& gradOutput,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices,
const Tensor& gradInput) {
  NoNamesGuard guard;

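  // gradInput must start from zeros: the backward kernel scatters gradOutput
  // into it at the positions recorded in indices. The const_cast works around
  // structured kernels exposing outputs as const Tensor& while the kernel
  // takes a mutable reference.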
  gradInput.zero_();
  max_pool2d_backward_kernel(
      kCPU, const_cast<Tensor&>(gradInput),
      gradOutput, indices);
}

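// Define the dispatch stubs for the pooling kernels (declared in
// ATen/native/Pool.h); device-specific implementations register themselves
// against these stubs via REGISTER_DISPATCH.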
DEFINE_DISPATCH(max_pool2d_kernel);
DEFINE_DISPATCH(max_pool2d_backward_kernel);

} // namespace at::native