1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorMeta.h>
6 #include <ATen/native/FractionalMaxPooling.h>
7
8 #include <c10/util/irange.h>
9
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/empty.h>
15 #include <ATen/ops/fractional_max_pool3d_backward_native.h>
16 #include <ATen/ops/fractional_max_pool3d_native.h>
17 #endif
18
19
20 namespace at::meta {
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
  const at::Tensor& input_,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples
) {
  // Both size arguments must describe the three pooled dimensions (T, H, W).
  TORCH_CHECK(
      pool_size.size() == 3,
      "fractional_max_pool3d: kernel_size must either be a single Int or tuple of three Ints")
  TORCH_CHECK(
      output_size.size() == 3,
      "fractional_max_pool3d: output_size must either be a single Int or tuple of three Ints")
  const int64_t outputT = output_size[0];
  const int64_t outputH = output_size[1];
  const int64_t outputW = output_size[2];
  const int64_t poolSizeT = pool_size[0];
  const int64_t poolSizeH = pool_size[1];
  const int64_t poolSizeW = pool_size[2];

  const int64_t ndims = input_.ndimension();
  TORCH_CHECK(ndims == 4 || ndims == 5,
    "fractional_max_pool3d_out(): Expected 4D or 5D tensor, but got: ",
    input_.sizes());
  for (const auto i : c10::irange(1, ndims)) {
    TORCH_CHECK(input_.size(i) > 0,
      "fractional_max_pool3d_out(): Expected input to have non-zero size for non-batch dimensions, but got",
      input_.sizes(), " with dimension ", i, " being empty.");
  }

  // A 5D input carries a leading batch dimension; all axis indices shift by one.
  const bool batched = (ndims == 5);
  const int64_t numBatch = batched ? input_.size(0) : 1;
  const int64_t planeDim = batched ? 1 : 0;

  /* sizes */
  const int64_t numPlanes = input_.size(planeDim);
  const int64_t inputT = input_.size(planeDim + 1);
  const int64_t inputH = input_.size(planeDim + 2);
  const int64_t inputW = input_.size(planeDim + 3);

  // Even the last output cell's pooling window must fit inside the input.
  TORCH_CHECK(outputT + poolSizeT - 1 < inputT,
           "fractional_max_pool3d_out(): pool time ", poolSizeT,
           " too large relative to input time ", inputT);
  TORCH_CHECK(outputW + poolSizeW - 1 < inputW,
           "fractional_max_pool3d_out(): pool width ", poolSizeW,
           " too large relative to input width ", inputW);
  TORCH_CHECK(outputH + poolSizeH - 1 < inputH,
           "fractional_max_pool3d_out(): pool height ", poolSizeH,
           " too large relative to input height ", inputH);

  if (batched) {
    /* resize output */
    set_output_raw_strided(0, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options());
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong));
  } else {
    set_output_raw_strided(0, {numPlanes, outputT, outputH, outputW}, {}, input_.options());
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong));
  }

  // Hand every precomputed size to the impl function so it need not re-derive them.
  return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)()
      .set_numBatch(numBatch).set_numPlanes(numPlanes)
      .set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
      .set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW)
      .set_outputT(outputT).set_outputH(outputH).set_outputW(outputW);
}
95
96 } // namespace at::meta
97
98 namespace at::native {
99 namespace {
100
101 template<typename scalar_t>
fractional_max_pool3d_out_single_batch_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW,int64_t poolSizeT,int64_t poolSizeH,int64_t poolSizeW)102 static void fractional_max_pool3d_out_single_batch_frame(
103 const scalar_t* input,
104 scalar_t* output,
105 int64_t* indices,
106 const scalar_t* randomSamples,
107 int64_t numPlanes,
108 int64_t inputT, int64_t inputH, int64_t inputW,
109 int64_t outputT, int64_t outputH, int64_t outputW,
110 int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
111
112 at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
113 for (const auto plane : c10::irange(start, end)) {
114 /* each plane contains 3 random samples,
115 one for T, one for W, and one for H */
116 const scalar_t* randomSamplesForPlane = randomSamples + plane * 3;
117
118 /* Generate interval sequence */
119 auto sequenceT = generate_intervals<scalar_t>(
120 randomSamplesForPlane[0], inputT, outputT, poolSizeT);
121 auto sequenceH = generate_intervals<scalar_t>(
122 randomSamplesForPlane[1], inputH, outputH, poolSizeH);
123 auto sequenceW = generate_intervals<scalar_t>(
124 randomSamplesForPlane[2], inputW, outputW, poolSizeW);
125
126 /* loop over output */
127 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
128 int64_t t, h, w;
129
130 const scalar_t* inputForPlane = input + plane * inputT * inputH * inputW;
131 scalar_t* outputForPlane = output + plane * outputT * outputH * outputW;
132 int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW;
133
134 for (t = 0; t < outputT; ++t) {
135 int64_t inputTStart = sequenceT[t];
136
137 for (h = 0; h < outputH; ++h) {
138 int64_t inputHStart = sequenceH[h];
139
140 for (w = 0; w < outputW; ++w) {
141 int64_t inputWStart = sequenceW[w];
142
143 int64_t t2 = inputTStart, h2 = inputHStart, w2 = inputWStart;
144 scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
145 int64_t maxIndex = t2 * inputH * inputW + h2 * inputW + w2;
146
147 for (t2 = inputTStart; t2 < inputTStart + poolSizeT; ++t2) {
148 for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
149 for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
150 AT_ASSERT(t2 >= 0 && t2 < inputT);
151 AT_ASSERT(h2 >= 0 && h2 < inputH);
152 AT_ASSERT(w2 >= 0 && w2 < inputW);
153
154 int64_t planeIndex = t2 * inputH * inputW + h2 * inputW + w2;
155 scalar_t val = inputForPlane[planeIndex];
156 if (val > maxVal || std::isnan(val)) {
157 maxVal = val;
158 maxIndex = planeIndex;
159 }
160 }
161 }
162 }
163
164 outputForPlane[t * outputH * outputW + h * outputW + w] = maxVal;
165 indicesForPlane[t * outputH * outputW + h * outputW + w] = maxIndex;
166 }
167 }
168 }
169 }
170 });
171 }
172
173 template<typename scalar_t>
fractional_max_pool3d_out_frame(const scalar_t * input,scalar_t * output,int64_t * indices,const scalar_t * randomSamples,int64_t numBatch,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW,int64_t poolSizeT,int64_t poolSizeH,int64_t poolSizeW)174 static void fractional_max_pool3d_out_frame(
175 const scalar_t* input,
176 scalar_t* output,
177 int64_t* indices,
178 const scalar_t* randomSamples,
179 int64_t numBatch, int64_t numPlanes,
180 int64_t inputT, int64_t inputH, int64_t inputW,
181 int64_t outputT, int64_t outputH, int64_t outputW,
182 int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
183 if(numBatch == 1) {
184 fractional_max_pool3d_out_single_batch_frame<scalar_t>(
185 input, output, indices, randomSamples,
186 numPlanes,
187 inputT, inputH, inputW,
188 outputT, outputH, outputW,
189 poolSizeT, poolSizeH, poolSizeW
190 );
191 return;
192 }
193
194 at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
195 for (const auto batch : c10::irange(start, end)) {
196 fractional_max_pool3d_out_single_batch_frame<scalar_t>(
197 input + batch * numPlanes * inputW * inputH * inputT,
198 output + batch * numPlanes * outputW * outputH * outputT,
199 indices + batch * numPlanes * outputW * outputH * outputT,
200 randomSamples + batch * numPlanes * 3,
201 numPlanes,
202 inputT, inputH, inputW,
203 outputT, outputH, outputW,
204 poolSizeT, poolSizeH, poolSizeW
205 );
206 }
207 });
208 }
209
210 } // anonymous namespace
211
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
  const at::Tensor& input_,
  int64_t poolSizeT,
  int64_t poolSizeH,
  int64_t poolSizeW,
  int64_t outputT,
  int64_t outputH,
  int64_t outputW,
  const at::Tensor& randomSamples_,
  int64_t numBatch,
  int64_t numPlanes,
  int64_t inputT,
  int64_t inputH,
  int64_t inputW,
  const at::Tensor& output,
  const at::Tensor& indices) {
  // Validate randomSamples against the input (shape/dtype/device).
  fractional_max_pool_check_shape</*ndim*/ 3>(input_, randomSamples_);

  // Nothing to compute for an empty output.
  if (output.numel() == 0) {
    return;
  }

  /* get contiguous input and samples */
  // The pooling kernels index raw pointers and assume dense layout.
  const auto input = input_.contiguous();
  const auto randomSamples = randomSamples_.contiguous();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16,
      kHalf,
      input.scalar_type(),
      "fractional_max_pool3d_out_frame",
      [&] {
        fractional_max_pool3d_out_frame<scalar_t>(
            input.const_data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            indices.data_ptr<int64_t>(),
            randomSamples.const_data_ptr<scalar_t>(),
            numBatch, numPlanes,
            inputT, inputH, inputW,
            outputT, outputH, outputW,
            poolSizeT, poolSizeH, poolSizeW);
      });
}
258
259 namespace {
260
261 template<typename scalar_t>
fractional_max_pool3d_backward_out_single_batch_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW)262 static void fractional_max_pool3d_backward_out_single_batch_frame(
263 scalar_t* gradInput,
264 const scalar_t* gradOutput,
265 const int64_t* indices,
266 int64_t numPlanes,
267 int64_t inputT, int64_t inputH, int64_t inputW,
268 int64_t outputT, int64_t outputH, int64_t outputW) {
269
270 at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
271 for (const auto plane : c10::irange(start, end)) {
272 scalar_t* gradInputForPlane = gradInput + plane * inputT * inputH * inputW;
273 const scalar_t* gradOutputForPlane = gradOutput +
274 plane * outputT * outputH * outputW;
275 const int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW;
276
277 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
278 int64_t h, w, t;
279 for (t = 0; t < outputT; ++t) {
280 for (h = 0; h < outputH; ++h) {
281 for (w = 0; w < outputW; ++w) {
282 int64_t outputIndex = t * outputH * outputW + h * outputW + w;
283 int64_t index = indicesForPlane[outputIndex];
284 AT_ASSERT(index >= 0 && index < inputT * inputH * inputW);
285 gradInputForPlane[index] += gradOutputForPlane[outputIndex];
286 }
287 }
288 }
289 }
290 });
291 }
292
293 template<typename scalar_t>
fractional_max_pool3d_backward_out_frame(scalar_t * gradInput,const scalar_t * gradOutput,const int64_t * indices,int64_t numBatch,int64_t numPlanes,int64_t inputT,int64_t inputH,int64_t inputW,int64_t outputT,int64_t outputH,int64_t outputW)294 static void fractional_max_pool3d_backward_out_frame(
295 scalar_t* gradInput,
296 const scalar_t* gradOutput,
297 const int64_t* indices,
298 int64_t numBatch, int64_t numPlanes,
299 int64_t inputT, int64_t inputH, int64_t inputW,
300 int64_t outputT, int64_t outputH, int64_t outputW) {
301 if(numBatch == 1) {
302 fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
303 gradInput, gradOutput, indices,
304 numPlanes,
305 inputT, inputH, inputW,
306 outputT, outputH, outputW
307 );
308 return;
309 }
310
311 at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
312 for (const auto batch : c10::irange(start, end)) {
313 fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
314 gradInput + batch * numPlanes * inputW * inputH * inputT,
315 gradOutput + batch * numPlanes * outputW * outputH * outputT,
316 indices + batch * numPlanes * outputW * outputH * outputT,
317 numPlanes,
318 inputT, inputH, inputW,
319 outputT, outputH, outputW
320 );
321 }
322 });
323 }
324
325
fractional_max_pool3d_backward_out_cpu_template(const Tensor & input,const Tensor & gradOutput_,Tensor & gradInput,IntArrayRef output_size,IntArrayRef pool_size,const Tensor & indices)326 void fractional_max_pool3d_backward_out_cpu_template(
327 const Tensor& input,
328 const Tensor& gradOutput_,
329 Tensor& gradInput,
330 IntArrayRef output_size,
331 IntArrayRef pool_size /* unused */,
332 const Tensor& indices) {
333
334 int64_t outputT = output_size[0];
335 int64_t outputH = output_size[1];
336 int64_t outputW = output_size[2];
337
338 int64_t numBatch = 1;
339 int64_t planeDim = 0;
340 int64_t timeDim = 1;
341 int64_t heightDim = 2;
342 int64_t widthDim = 3;
343
344 int64_t ndims = input.ndimension();
345 if (ndims == 5) {
346 numBatch = input.size(0);
347 planeDim = 1;
348 heightDim++;
349 widthDim++;
350 timeDim++;
351 }
352
353 /* sizes */
354 int64_t numPlanes = input.size(planeDim);
355 int64_t inputT = input.size(timeDim);
356 int64_t inputH = input.size(heightDim);
357 int64_t inputW = input.size(widthDim);
358
359 TORCH_CHECK(outputT == gradOutput_.size(timeDim),
360 "fractional_max_pool3d_backward_out(): gradOutput time unexpected");
361 TORCH_CHECK(outputH == gradOutput_.size(heightDim),
362 "fractional_max_pool3d_backward_out(): ",
363 "gradOutput height unexpected");
364 TORCH_CHECK(outputW == gradOutput_.size(widthDim),
365 "fractional_max_pool3d_backward_out(): gradOutput width unexpected");
366
367 /* get contiguous gradOutput */
368 auto gradOutput = gradOutput_.contiguous();
369
370 /* resize */
371 gradInput.resize_as_(input);
372 gradInput.zero_();
373
374 /* backprop */
375 AT_DISPATCH_FLOATING_TYPES_AND2(
376 kBFloat16,
377 kHalf,
378 input.scalar_type(),
379 "fractional_max_pool3d_backward_out_frame",
380 [&]{
381 fractional_max_pool3d_backward_out_frame<scalar_t>(
382 gradInput.data_ptr<scalar_t>(),
383 gradOutput.const_data_ptr<scalar_t>(),
384 indices.const_data_ptr<int64_t>(),
385 numBatch, numPlanes,
386 inputT, inputH, inputW,
387 outputT, outputH, outputW
388 );
389 }
390 );
391 }
392
393 }// anonymous namespace
394
fractional_max_pool3d_backward_out_cpu(const at::Tensor & gradOutput_,const at::Tensor & input,IntArrayRef pool_size,IntArrayRef output_size,const at::Tensor & indices,at::Tensor & gradInput)395 Tensor& fractional_max_pool3d_backward_out_cpu(const at::Tensor& gradOutput_,
396 const at::Tensor& input,
397 IntArrayRef pool_size,
398 IntArrayRef output_size,
399 const at::Tensor& indices,
400 at::Tensor& gradInput) {
401 fractional_max_pool3d_backward_out_cpu_template(
402 input,
403 gradOutput_,
404 gradInput,
405 output_size,
406 pool_size,
407 indices);
408 return gradInput;
409 }
410
fractional_max_pool3d_backward_cpu(const at::Tensor & gradOutput_,const at::Tensor & input,IntArrayRef pool_size,IntArrayRef output_size,const at::Tensor & indices)411 Tensor fractional_max_pool3d_backward_cpu(
412 const at::Tensor& gradOutput_,
413 const at::Tensor& input,
414 IntArrayRef pool_size,
415 IntArrayRef output_size,
416 const at::Tensor& indices) {
417 Tensor gradInput = at::empty({0}, input.options());
418 fractional_max_pool3d_backward_out_cpu_template(
419 input,
420 gradOutput_,
421 gradInput,
422 output_size,
423 pool_size,
424 indices);
425 return gradInput;
426 }
427
428 } // namespace at::native
429