1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorMeta.h>
6 #include <ATen/native/FractionalMaxPooling.h>
7 #include <c10/util/irange.h>
8
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/Functions.h>
11 #include <ATen/NativeFunctions.h>
12 #else
13 #include <ATen/ops/fractional_max_pool2d_backward_native.h>
14 #include <ATen/ops/fractional_max_pool2d_native.h>
15 #endif
16
17 namespace at {
18
19 namespace meta {
/*
 * Meta (shape-inference / validation) function for fractional_max_pool2d.
 *
 * input:         3-D tensor indexed as (plane, height, width) or 4-D tensor
 *                indexed as (batch, plane, height, width).
 * pool_size:     exactly two values, {poolSizeH, poolSizeW}.
 * output_size:   exactly two values, {outputH, outputW}.
 * randomSamples: not inspected by this function.
 *
 * Raises (via TORCH_CHECK) when pool_size / output_size do not hold two
 * values, when the input is not 3-D or 4-D, when any non-batch dimension is
 * empty, or when the pooling window does not fit into the input.
 *
 * Declares two outputs with the same leading dimensions as the input and
 * spatial size (outputH, outputW): output 0 uses the input's options, output 1
 * (the indices) uses the input's options with dtype kLong.
 */
TORCH_META_FUNC(fractional_max_pool2d) (
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples
) {
  TORCH_CHECK(
      pool_size.size() == 2,
      "fractional_max_pool2d: kernel_size must either be a single Int or tuple of Ints");
  TORCH_CHECK(
      output_size.size() == 2,
      "fractional_max_pool2d: output_size must either be a single Int or tuple of Ints");
  int64_t numBatch = 1;
  int64_t planeDim = 0;
  int64_t heightDim = 1;
  int64_t widthDim = 2;
  int64_t outputH = output_size[0];
  int64_t outputW = output_size[1];
  int64_t poolSizeH = pool_size[0];
  int64_t poolSizeW = pool_size[1];

  int64_t ndims = input.ndimension();
  TORCH_CHECK(ndims == 3 || ndims == 4,
              "fractional_max_pool2d(): Expected 3D or 4D tensor, but got: ", input.sizes());
  /* dimension 0 is deliberately skipped: only non-batch dimensions must be non-empty */
  for (const auto i : c10::irange(1, ndims)) {
    TORCH_CHECK(input.size(i) > 0,
                "fractional_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, but got ",
                input.sizes(), " with dimension ", i, " being empty.");
  }

  /* a leading batch dimension shifts plane/height/width by one */
  if (ndims == 4) {
    numBatch = input.size(0);
    planeDim++;
    heightDim++;
    widthDim++;
  }

  /* sizes */
  int64_t numPlanes = input.size(planeDim);
  int64_t inputH = input.size(heightDim);
  int64_t inputW = input.size(widthDim);

  /* a window starting at the last output position must still lie inside the input */
  TORCH_CHECK(outputH + poolSizeH - 1 <= inputH,
              "fractional_max_pool2d(): pool height ", poolSizeH,
              " too large relative to input height ", inputH);
  TORCH_CHECK(outputW + poolSizeW - 1 <= inputW,
              "fractional_max_pool2d(): pool width ", poolSizeW,
              " too large relative to input width ", inputW);

  if (ndims == 3) {
    set_output_raw_strided(0, {numPlanes, outputH, outputW}, {}, input.options());
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {numPlanes, outputH, outputW}, {}, input.options().dtype(kLong));
  } else {
    set_output_raw_strided(0, {numBatch, numPlanes, outputH, outputW}, {}, input.options());
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {numBatch, numPlanes, outputH, outputW}, {}, input.options().dtype(kLong));
  }
}
80
/*
 * Meta (shape-inference / validation) function for
 * fractional_max_pool2d_backward.
 *
 * gradOutput_: gradient w.r.t. the forward output; its height and width must
 *              equal output_size.
 * input:       the forward input, 3-D (plane, height, width) or 4-D
 *              (batch, plane, height, width).
 * pool_size:   unused.
 * output_size: exactly two values, {outputH, outputW}.
 * indices:     not inspected by this function.
 *
 * Declares a single output (the gradient w.r.t. input) with the input's
 * shape and options.
 */
TORCH_META_FUNC(fractional_max_pool2d_backward)(
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size /* unused */,
  IntArrayRef output_size,
  const at::Tensor& indices) {

  /* validate before indexing: output_size[0] / [1] are read right below */
  TORCH_CHECK(
      output_size.size() == 2,
      "fractional_max_pool2d_backward(): output_size must contain exactly two values, but got ",
      output_size.size());

  int64_t numBatch = 1;
  int planeDim = 0;
  int heightDim = 1;
  int widthDim = 2;

  auto outputH = output_size[0];
  auto outputW = output_size[1];

  auto ndims = input.ndimension();
  /* same rank requirement as the forward meta function */
  TORCH_CHECK(ndims == 3 || ndims == 4,
              "fractional_max_pool2d_backward(): Expected 3D or 4D tensor, but got: ", input.sizes());
  if (ndims == 4) {
    numBatch = input.size(0);
    planeDim = 1;
    heightDim++;
    widthDim++;
  }

  /* sizes */
  auto numPlanes = input.size(planeDim);
  auto inputH = input.size(heightDim);
  auto inputW = input.size(widthDim);

  /* only sizes are needed here, so gradOutput_ is queried directly rather
     than through a (potentially copying) contiguous() call */
  TORCH_CHECK(outputW == gradOutput_.size(widthDim),
              "fractional_max_pool2d_backward(): gradOutput width unexpected");
  TORCH_CHECK(outputH == gradOutput_.size(heightDim),
              "fractional_max_pool2d_backward(): gradOutput height unexpected");

  /* resize */
  if (ndims == 3) {
    set_output_raw_strided(0, {numPlanes, inputH, inputW}, {}, input.options());
  } else {
    set_output_raw_strided(0, {numBatch, numPlanes, inputH, inputW}, {}, input.options());
  }
}
124 } // namespace meta
125
126 namespace native {
127 namespace {
128
129 template <typename scalar_t>
fractional_max_pool2d_out_single_batch_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int numPlanes,int inputW,int inputH,int outputW,int outputH,int poolSizeW,int poolSizeH)130 static void fractional_max_pool2d_out_single_batch_frame(
131 const scalar_t* input,
132 scalar_t* output,
133 int64_t* indices,
134 const scalar_t* randomSamples,
135 int numPlanes,
136 int inputW, int inputH,
137 int outputW, int outputH,
138 int poolSizeW, int poolSizeH) {
139 at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
140 for (const auto plane : c10::irange(start, end)) {
141 /* each plane contains 2 random samples, one for W and one for H */
142 const scalar_t* randomSamplesForPlane = randomSamples + plane * 2;
143
144 /* Generate interval sequence */
145 auto sequenceW = generate_intervals<scalar_t>(
146 randomSamplesForPlane[0], inputW, outputW, poolSizeW);
147 auto sequenceH = generate_intervals<scalar_t>(
148 randomSamplesForPlane[1], inputH, outputH, poolSizeH);
149
150 /* loop over output */
151 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
152 int h, w;
153
154 const scalar_t* inputForPlane = input + plane * inputW * inputH;
155 scalar_t* outputForPlane = output + plane * outputW * outputH;
156 int64_t* indicesForPlane = indices + plane * outputW * outputH;
157
158 for (h = 0; h < outputH; ++h) {
159 int inputHStart = sequenceH[h];
160
161 for (w = 0; w < outputW; ++w) {
162 int inputWStart = sequenceW[w];
163
164 int h2 = inputHStart, w2 = inputWStart;
165 scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
166 int64_t maxIndex = h2 * inputW + w2;
167
168 for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
169 for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
170 AT_ASSERT(h2 >= 0 && h2 < inputH);
171 AT_ASSERT(w2 >= 0 && w2 < inputW);
172
173 int planeIndex = h2 * inputW + w2;
174 scalar_t val = inputForPlane[planeIndex];
175 if (val > maxVal || std::isnan(val)) {
176 maxVal = val;
177 maxIndex = planeIndex;
178 }
179 }
180 }
181
182 outputForPlane[h * outputW + w] = maxVal;
183 indicesForPlane[h * outputW + w] = maxIndex;
184 }
185 }
186 }
187 });
188 }
189
190 template <typename scalar_t>
fractional_max_pool2d_out_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int numBatch,int numPlanes,int inputW,int inputH,int outputW,int outputH,int poolSizeW,int poolSizeH)191 static void fractional_max_pool2d_out_frame(
192 const scalar_t* input,
193 scalar_t* output,
194 int64_t* indices,
195 const scalar_t* randomSamples,
196 int numBatch, int numPlanes,
197 int inputW, int inputH,
198 int outputW, int outputH,
199 int poolSizeW, int poolSizeH) {
200 if(numBatch == 1) {
201 fractional_max_pool2d_out_single_batch_frame<scalar_t>(
202 input,
203 output,
204 indices,
205 randomSamples,
206 numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH
207 );
208 return;
209 }
210 at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
211 for (const auto batch : c10::irange(start, end)) {
212 fractional_max_pool2d_out_single_batch_frame<scalar_t>(
213 input + batch * numPlanes * inputH * inputW,
214 output + batch * numPlanes * outputH * outputW,
215 indices + batch * numPlanes * outputH * outputW,
216 randomSamples + batch * numPlanes * 2,
217 numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
218 }
219 });
220 }
221
222 template <typename scalar_t>
fractional_max_pool2d_backward_out_single_batch_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int numPlanes,int inputW,int inputH,int outputW,int outputH)223 static void fractional_max_pool2d_backward_out_single_batch_frame(
224 scalar_t* gradInput,
225 const scalar_t* gradOutput,
226 const int64_t* indices,
227 int numPlanes,
228 int inputW, int inputH,
229 int outputW, int outputH) {
230 at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
231 for (const auto plane : c10::irange(start, end)) {
232 scalar_t* gradInputForPlane = gradInput + plane * inputW * inputH;
233 const scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
234 const int64_t* indicesForPlane = indices + plane * outputW * outputH;
235
236 for (int h = 0; h < outputH; ++h) {
237 for (int w = 0; w < outputW; ++w) {
238 int outputIndex = h * outputW + w;
239 int64_t index = indicesForPlane[outputIndex];
240 AT_ASSERT(index >= 0 && index < static_cast<int64_t>(inputW) * inputH);
241
242 gradInputForPlane[index] += gradOutputForPlane[outputIndex];
243 }
244 }
245 }
246 });
247 }
248
249 template <typename scalar_t>
fractional_max_pool2d_backward_out_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int numBatch,int numPlanes,int inputW,int inputH,int outputW,int outputH)250 static void fractional_max_pool2d_backward_out_frame(
251 scalar_t* gradInput,
252 const scalar_t* gradOutput,
253 const int64_t* indices,
254 int numBatch, int numPlanes,
255 int inputW, int inputH,
256 int outputW, int outputH) {
257 if(numBatch == 1) {
258 fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
259 gradInput, gradOutput, indices,
260 numPlanes,
261 inputW, inputH, outputW, outputH
262 );
263 return;
264 }
265 at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
266 for (const auto batch : c10::irange(start, end)) {
267 fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
268 gradInput + batch * numPlanes * inputH * inputW,
269 gradOutput + batch * numPlanes * outputH * outputW,
270 indices + batch * numPlanes * outputH * outputW,
271 numPlanes, inputW, inputH, outputW, outputH);
272 }
273 });
274 }
275
276 } // anonymous namespace
277
/*
 * CPU implementation of fractional_max_pool2d.
 *
 * Runs the shape check on (input_, randomSamples_), returns early when the
 * output holds no elements, and otherwise dispatches on the input's scalar
 * type (floating types plus BFloat16 and Half) to
 * fractional_max_pool2d_out_frame, which fills `output` and `indices`.
 * The input and the random samples are made contiguous first because the
 * kernel walks them through raw pointers.
 */
TORCH_IMPL_FUNC(fractional_max_pool2d_out_cpu) (
  const at::Tensor& input_,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples_,
  const at::Tensor& output,
  const at::Tensor& indices) {

  fractional_max_pool_check_shape</*ndim*/ 2>(input_, randomSamples_);

  /* nothing to compute for an empty output */
  if (output.numel() == 0) {
    return;
  }

  /* requested geometry: {H, W} order in both argument lists */
  const int64_t outputH = output_size[0];
  const int64_t outputW = output_size[1];
  const int64_t poolSizeH = pool_size[0];
  const int64_t poolSizeW = pool_size[1];

  /* get contiguous input and samples */
  auto input = input_.contiguous();
  auto randomSamples = randomSamples_.contiguous();

  /* a 4-D input carries a leading batch dimension that shifts the others */
  const bool hasBatchDim = (input.ndimension() == 4);
  const int64_t dimShift = hasBatchDim ? 1 : 0;
  const int64_t numBatch = hasBatchDim ? input.size(0) : 1;

  /* sizes */
  const int64_t numPlanes = input.size(dimShift);
  const int64_t inputH = input.size(dimShift + 1);
  const int64_t inputW = input.size(dimShift + 2);

  AT_DISPATCH_FLOATING_TYPES_AND2(
    kBFloat16,
    kHalf,
    input.scalar_type(),
    "fractional_max_pool2d_out_frame", [&] {
      auto src = input.const_data_ptr<scalar_t>();
      auto dst = output.data_ptr<scalar_t>();
      auto dstIndices = indices.data_ptr<int64_t>();
      auto samples = randomSamples.const_data_ptr<scalar_t>();
      fractional_max_pool2d_out_frame<scalar_t>(
        src,
        dst,
        dstIndices,
        samples,
        numBatch, numPlanes,
        inputW, inputH,
        outputW, outputH,
        poolSizeW, poolSizeH);
    }
  );
}
340
/*
 * CPU implementation of fractional_max_pool2d_backward.
 *
 * Zeroes gradInput, then dispatches on the input's scalar type (floating
 * types plus BFloat16 and Half) to fractional_max_pool2d_backward_out_frame,
 * which adds every gradOutput element onto the gradInput element selected by
 * the matching entry of `indices`.
 *
 * gradOutput_ and indices are both made contiguous before their raw pointers
 * are taken, because the kernel addresses them with dense
 * plane * outputW * outputH offsets. pool_size is not used.
 */
TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cpu) (
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices,
  const at::Tensor& gradInput) {

  /* the kernel accumulates with +=, so start from zero */
  gradInput.zero_();

  int64_t numBatch = 1;
  int planeDim = 0;
  int heightDim = 1;
  int widthDim = 2;

  auto outputH = output_size[0];
  auto outputW = output_size[1];

  /* a 4-D input carries a leading batch dimension that shifts the others */
  auto ndims = input.ndimension();
  if (ndims == 4) {
    numBatch = input.size(0);
    planeDim = 1;
    heightDim++;
    widthDim++;
  }

  /* sizes */
  auto numPlanes = input.size(planeDim);
  auto inputH = input.size(heightDim);
  auto inputW = input.size(widthDim);

  /* get contiguous gradOutput and indices */
  auto gradOutput = gradOutput_.contiguous();
  auto indicesContiguous = indices.contiguous();

  /* backprop */
  AT_DISPATCH_FLOATING_TYPES_AND2(
    kBFloat16,
    kHalf,
    input.scalar_type(), "fractional_max_pool2d_backward_out_frame", [&] {
      auto gradInput_data = gradInput.data_ptr<scalar_t>();
      auto gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
      auto indices_data = indicesContiguous.const_data_ptr<int64_t>();
      fractional_max_pool2d_backward_out_frame<scalar_t>(
        gradInput_data,
        gradOutput_data,
        indices_data,
        numBatch, numPlanes,
        inputW, inputH,
        outputW, outputH
      );
    }
  );
}
394
395 } // at::native
396 } // at
397