#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/AdaptivePooling.h>
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <type_traits>
#include <ATen/OpMathType.h>
#include <ATen/native/ReduceOpsUtils.h>

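// CPU kernels for max_pool2d / max_pool3d forward and backward, registered via
// REGISTER_DISPATCH at the bottom of this file. Both the contiguous (NCHW / NCDHW)
// and channels-last (NHWC / NDHWC) memory formats are handled; the channels-last
// forward path is vectorized across the channel dimension.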
namespace at::native {

namespace {

template <typename scalar_t>
bool is_nan(scalar_t v) {
  if (std::is_integral<scalar_t>::value || std::is_same<scalar_t, unsigned char>::value) {
    return false;
  }
  return std::isnan(v);
}
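
// is_nan / is_nan_vec report whether a value (or each vector lane) is NaN. The
// specializations below cover the integral lane types, which can never hold a NaN,
// so they simply return an all-false (all-zero) vector.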

template <typename scalar_t>
vec::Vectorized<scalar_t> is_nan_vec(vec::Vectorized<scalar_t> vec) {
  return vec.isnan();
}

// TODO: use is_integral/is_same to check the scalar_t and simplify the implementation;
// currently it does not work.
template <>
vec::Vectorized<unsigned char> is_nan_vec<unsigned char>(vec::Vectorized<unsigned char> vec) {
  Vectorized<unsigned char> ret(false);
  return ret;
}

template <>
vec::Vectorized<signed char> is_nan_vec<signed char>(vec::Vectorized<signed char> vec) {
  Vectorized<signed char> ret(false);
  return ret;
}

template <>
vec::Vectorized<short> is_nan_vec<short>(vec::Vectorized<short> vec) {
  Vectorized<short> ret(false);
  return ret;
}

template <>
vec::Vectorized<int> is_nan_vec<int>(vec::Vectorized<int> vec) {
  Vectorized<int> ret(false);
  return ret;
}

template <>
vec::Vectorized<int64_t> is_nan_vec<int64_t>(vec::Vectorized<int64_t> vec) {
  Vectorized<int64_t> ret(false);
  return ret;
}
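
// A possible shape for the simplification mentioned in the TODO above. This is only a
// sketch and is intentionally left disabled, since the TODO notes that the unified form
// does not currently build for all vectorized types:
//
//   template <typename scalar_t>
//   vec::Vectorized<scalar_t> is_nan_vec(vec::Vectorized<scalar_t> vec) {
//     if constexpr (std::is_integral_v<scalar_t>) {
//       return vec::Vectorized<scalar_t>(false);  // integral lanes can never be NaN
//     } else {
//       return vec.isnan();
//     }
//   }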

template <typename scalar_t, typename opmath_t>
inline
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
compute_internal(
    const scalar_t* input_data,
    scalar_t* out_data,
    opmath_t* max_ptr,
    vec::int_same_size_t<opmath_t>* index_ptr,
    int64_t* ind,
    int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
    int64_t n,
    int64_t len,
    int64_t size,
    int64_t id0, int64_t id1,
    int64_t ih0, int64_t ih1,
    int64_t iw0, int64_t iw1,
    int64_t dilationD,
    int64_t dilationH,
    int64_t dilationW) {
  using Vec = vec::Vectorized<scalar_t>;
  using integer_t = vec::int_same_size_t<opmath_t>;
  using iVec = vec::Vectorized<integer_t>;
  // Pass I: init out lane
  iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);

  scalar_t min_value = lower_bound<scalar_t>();
  Vec out_vec = Vec(min_value);
  int64_t d1 = 0;
  for (; d1 < len; d1 += Vec::size()) {
    index0_vec.store(index_ptr + d1);
    out_vec.store(out_data + d1);
  }
  for (; d1 < size; d1++) {
    ind[d1] = id0 * input_height * input_width + ih0 * input_width + iw0;
    out_data[d1] = min_value;
  }
  // Pass II: compute local max
  for (int64_t id = id0; id < id1; id += dilationD) {
    for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
      for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
        const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;

        int64_t d2 = 0;
        for (; d2 < len; d2 += Vec::size()) {
          iVec index_vec = iVec(id * input_height * input_width + ih * input_width + iw);
          Vec val_vec = Vec::loadu(in + d2);
          iVec maxindex_vec = iVec::loadu(index_ptr + d2);
          Vec maxval_vec = Vec::loadu(out_data + d2);

          // true = all ones, false = all zeros
          Vec mask = (val_vec > maxval_vec) | is_nan_vec(val_vec);
          iVec imask = vec::cast<integer_t>(mask);
          Vec out_vec = Vec::blendv(maxval_vec, val_vec, mask);
          iVec ind_vec = iVec::blendv(maxindex_vec, index_vec, imask);

          out_vec.store(out_data + d2);
          ind_vec.store(index_ptr + d2);
        }
        for (; d2 < size; d2++) {
          int64_t index = id * input_height * input_width + ih * input_width + iw;
          scalar_t val = in[d2];
          int64_t maxindex = ind[d2];
          scalar_t maxval = out_data[d2];

          bool mask = (val > maxval) || is_nan(static_cast<double>(val));
          out_data[d2] = mask ? val : maxval;
          ind[d2] = mask ? index : maxindex;
        }
      }
    }
  }
}

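// Reduced-precision overload of compute_internal: the running max is accumulated in
// an opmath_t (float) buffer, max_ptr, and converted back to the scalar type once per
// output position. Selected when scalar_t differs from opmath_t, i.e. when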
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
inline
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
compute_internal(
    const scalar_t* input_data,
    scalar_t* out_data,
    opmath_t* max_ptr,
    vec::int_same_size_t<opmath_t>* index_ptr,
    int64_t* ind,
    int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
    int64_t n,
    int64_t len,
    int64_t size,
    int64_t id0, int64_t id1,
    int64_t ih0, int64_t ih1,
    int64_t iw0, int64_t iw1,
    int64_t dilationD,
    int64_t dilationH,
    int64_t dilationW) {
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<opmath_t>;
  using iVec = vec::Vectorized<int32_t>;
  // Pass I: init out lane
  iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
  fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
  int64_t d1 = 0;
  for (; d1 < len; d1 += fVec::size()) {
    index0_vec.store(index_ptr + d1);
    out_vec.store(max_ptr + d1);
  }
  for (; d1 < size; d1++) {
    ind[d1] = id0 * input_height * input_width + ih0 * input_width + iw0;
    max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
  }
  // Pass II: compute local max
  for (int64_t id = id0; id < id1; id += dilationD) {
    for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
      for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
        const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;

        int64_t d2 = 0;
        for (; d2 < len; d2 += Vec::size()) {
          iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
          Vec val_bvec = Vec::loadu(in + d2);
          auto [val_fvec0, val_fvec1] = convert_to_float<scalar_t>(val_bvec);

          iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
          iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
          fVec maxval_fvec0 = fVec::loadu(max_ptr + d2);
          fVec maxval_fvec1 = fVec::loadu(max_ptr + d2 + fVec::size());

          // true = all ones, false = all zeros
          fVec mask0 = (val_fvec0 > maxval_fvec0) | is_nan_vec(val_fvec0);
          fVec mask1 = (val_fvec1 > maxval_fvec1) | is_nan_vec(val_fvec1);
          iVec imask0 = vec::cast<int32_t>(mask0);
          iVec imask1 = vec::cast<int32_t>(mask1);

          fVec max_fvec0 = fVec::blendv(maxval_fvec0, val_fvec0, mask0);
          fVec max_fvec1 = fVec::blendv(maxval_fvec1, val_fvec1, mask1);
          iVec ind_vec0 = iVec::blendv(maxindex_ivec0, index_ivec, imask0);
          iVec ind_vec1 = iVec::blendv(maxindex_ivec1, index_ivec, imask1);

          max_fvec0.store(max_ptr + d2);
          max_fvec1.store(max_ptr + d2 + fVec::size());
          // out_vec.store(out + d2);
          ind_vec0.store(index_ptr + d2);
          ind_vec1.store(index_ptr + d2 + iVec::size());
        }
        for (; d2 < size; d2++) {
          int64_t index = id * input_height * input_width + ih * input_width + iw;
          opmath_t val = opmath_t(in[d2]);
          int64_t maxindex = ind[d2];
          opmath_t maxval = max_ptr[d2];

          bool mask = (val > maxval) || std::isnan(val);
          max_ptr[d2] = mask ? val : maxval;
          ind[d2] = mask ? index : maxindex;
        }
      }
    }
  }
  // Convert max values from float to bfloat16/half
  int64_t d3 = 0;
  for (; d3 < len; d3 += Vec::size()) {
    fVec max_fvec0 = fVec::loadu(max_ptr + d3);
    fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
    Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
    max_bvec.store(out_data + d3);
  }
  for (; d3 < size; d3++) {
    out_data[d3] = scalar_t(max_ptr[d3]);
  }
}

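// Forward kernel for the contiguous (NCHW / NCDHW) layout. Parallelizes over the
// flattened N*C dimension; for every output element the pooling window is scanned
// and both the max value and its flattened spatial index
// (id * input_height * input_width + ih * input_width + iw) are recorded.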
template <typename scalar_t, bool is_3d>
void cpu_max_pool(
    const Tensor& output_,
    const Tensor& indices_,
    const Tensor& input_,
    IntArrayRef kWHD,
    IntArrayRef dWHD,
    IntArrayRef padWHD,
    IntArrayRef dilWHD) {
  size_t dims = is_3d ? 3 : 2;
  TORCH_CHECK(kWHD.size() == dims && dWHD.size() == dims && padWHD.size() == dims && dilWHD.size() == dims,
      "max pooling 2d/3d: kernel size, stride, padding and dilation must each have ", dims, " elements");
  int kW = kWHD[0];
  int kH = kWHD[1];
  int dW = dWHD[0];
  int dH = dWHD[1];
  int padW = padWHD[0];
  int padH = padWHD[1];
  int dilationW = dilWHD[0];
  int dilationH = dilWHD[1];

  int kD = is_3d ? kWHD[dims - 1] : 1;
  int dD = is_3d ? dWHD[dims - 1] : 1;
  int padD = is_3d ? padWHD[dims - 1] : 0;
  int dilationD = is_3d ? dilWHD[dims - 1] : 1;

  auto input = input_.contiguous();
  auto output = output_.contiguous();
  auto indices = indices_.contiguous();

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();
  auto indices_data = indices.data_ptr<int64_t>();

  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  //
  // MaxPool2d:
  //   ndim == 3: CHW
  //   ndim == 4: NCHW
  //
  // MaxPool3d:
  //   ndim == 4: CDHW
  //   ndim == 5: NCDHW
  int64_t channels;
  if (is_3d) {
    channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
  } else {
    channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
  }
  int64_t input_depth = is_3d ? input.size(-3) : 1;
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_depth = is_3d ? output.size(-3) : 1;
  int64_t output_height = output.size(-2);
  int64_t output_width = output.size(-1);

  using opmath_t = at::opmath_type<scalar_t>;
  // parallel on dim N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (int64_t c = begin; c < end; c++) {
      const scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
      scalar_t* output_ptr = output_data + c * output_depth * output_height * output_width;
      int64_t* indices_ptr = indices_data + c * output_depth * output_height * output_width;

      for (int64_t od = 0; od < output_depth; od++) {
        int64_t id0 = od * dD - padD;
        int64_t id1 = std::min(id0 + (kD - 1) * dilationD + 1, input_depth);
        while (id0 < 0) { id0 += dilationD; }

        for (int64_t oh = 0; oh < output_height; oh++) {
          int64_t ih0 = oh * dH - padH;
          int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
          while (ih0 < 0) { ih0 += dilationH; }

          for (int64_t ow = 0; ow < output_width; ow++) {
            int64_t iw0 = ow * dW - padW;
            int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
            while (iw0 < 0) { iw0 += dilationW; }
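            // A negative start index (id0 / ih0 / iw0) points into the padding region;
            // the while loops above step it forward by the dilation so the first element
            // visited is in bounds and stays on the dilated sampling grid.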

            // compute local max
            int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
            opmath_t maxval;
            if (std::numeric_limits<opmath_t>::has_infinity) {
              maxval = -std::numeric_limits<opmath_t>::infinity();
            } else {
              maxval = std::numeric_limits<opmath_t>::min();
            }

            for (int64_t id = id0; id < id1; id += dilationD) {
              for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
                for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
                  int64_t index = id * input_height * input_width + ih * input_width + iw;
                  opmath_t val = input_ptr[index];
                  if ((val > maxval) || is_nan(static_cast<double>(val))) {
                    maxval = val;
                    maxindex = index;
                  }
                }
              }
            }

            // set output to local max and store location of max
            int64_t i = od * output_height * output_width + oh * output_width + ow;
            output_ptr[i] = scalar_t(maxval);
            indices_ptr[i] = maxindex;
          }
        }
      }
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
  if (!indices_.is_contiguous()) {
    indices_.copy_(indices);
  }
}

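// Forward kernel for the channels-last (NHWC / NDHWC) layout. Parallelizes over the
// flattened output positions (N, {D}, H, W) and vectorizes across the channel
// dimension. The per-window reduction is done in compute_internal; indices are
// accumulated in an integer_t buffer and widened to int64_t afterwards.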
template <typename scalar_t, bool is_3d>
void cpu_max_pool_channels_last(
    const Tensor& output_,
    const Tensor& indices_,
    const Tensor& input_,
    IntArrayRef kWHD,
    IntArrayRef dWHD,
    IntArrayRef padWHD,
    IntArrayRef dilWHD) {
  size_t dims = is_3d ? 3 : 2;
  TORCH_CHECK(kWHD.size() == dims && dWHD.size() == dims && padWHD.size() == dims && dilWHD.size() == dims,
      "max pooling 2d/3d: kernel size, stride, padding and dilation must each have ", dims, " elements");
  int64_t ndim = input_.ndimension();
  // MaxPool2d: NHWC
  // MaxPool3d: NDHWC
  if (is_3d) {
    TORCH_CHECK(ndim == 5, "max pooling 3d with channels last format supports tensors with 5 dims");
  } else {
    TORCH_CHECK(ndim == 4, "max pooling 2d with channels last format supports tensors with 4 dims");
  }

  int kW = kWHD[0];
  int kH = kWHD[1];
  int dW = dWHD[0];
  int dH = dWHD[1];
  int padW = padWHD[0];
  int padH = padWHD[1];
  int dilationW = dilWHD[0];
  int dilationH = dilWHD[1];

  int kD = is_3d ? kWHD[dims - 1] : 1;
  int dD = is_3d ? dWHD[dims - 1] : 1;
  int padD = is_3d ? padWHD[dims - 1] : 0;
  int dilationD = is_3d ? dilWHD[dims - 1] : 1;

  auto memory_format = is_3d ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();
  auto indices_data = indices.data_ptr<int64_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_depth = is_3d ? input.size(-3) : 1;
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_depth = is_3d ? output.size(-3) : 1;
  int64_t output_height = output.size(-2);
  int64_t output_width = output.size(-1);

  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<scalar_t>;
  using integer_t = vec::int_same_size_t<opmath_t>;
  // For the convenience of vectorization, use an integer of the same size as opmath_t,
  // e.g. int32_t for float, int64_t for double.
  // Make sure the flattened spatial index cannot overflow it.
  TORCH_CHECK(input_depth * input_height * input_width <= std::numeric_limits<integer_t>::max());

  // parallel on dim N, {D}, H, W
  at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t od = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

    int64_t size = channels;
    int64_t len = size - (size % Vec::size());
    // temp buffer holding index with integer_t
    auto index_buffer = std::make_unique<integer_t []>(len);
    integer_t* index_ptr = index_buffer.get();
    // temp buffer holding max value with opmath_t
    std::unique_ptr<opmath_t []> max_arr;
    opmath_t* max_ptr = nullptr;
    if (!std::is_same<scalar_t, opmath_t>::value) {
      max_arr = std::make_unique<opmath_t[]>(size);
      max_ptr = max_arr.get();
    }

    for (int64_t i = begin; i < end; i++) {
      int64_t id0 = od * dD - padD;
      int64_t ih0 = oh * dH - padH;
      int64_t iw0 = ow * dW - padW;
      int64_t id1 = std::min(id0 + (kD - 1) * dilationD + 1, input_depth);
      int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
      int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
      while (id0 < 0) { id0 += dilationD; }
      while (ih0 < 0) { ih0 += dilationH; }
      while (iw0 < 0) { iw0 += dilationW; }

      scalar_t* out = output_data + i * channels;
      int64_t* ind = indices_data + i * channels;

      compute_internal(input_data, out, max_ptr, index_ptr, ind, input_depth, input_height, input_width, channels,
                       n, len, size, id0, id1, ih0, ih1, iw0, iw1,
                       dilationD, dilationH, dilationW);

      // convert the index data type from integer_t to int64_t
      vec::convert<integer_t, int64_t>(index_buffer.get(), ind, len);

      // move on to next output index
      data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
  if (!indices_.is_contiguous(memory_format)) {
    indices_.copy_(indices);
  }
}

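// Backward kernel for the contiguous (NCHW / NCDHW) layout. Parallelized over the
// flattened N*C dimension; for each output position the stored argmax selects the
// input element that receives the gradient (grad_input[maxindex] += grad_output[index]).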
template <typename scalar_t, bool is_3d>
void cpu_max_pool_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const Tensor& indices_) {
  auto grad_output = grad_output_.contiguous();
  auto indices = indices_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto indices_data = indices.const_data_ptr<int64_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  // treat batch size and channels as one dimension
  //
  // MaxPool2d:
  //   ndim == 3: CHW
  //   ndim == 4: NCHW
  //
  // MaxPool3d:
  //   ndim == 4: CDHW
  //   ndim == 5: NCDHW
  int64_t ndim = grad_output.ndimension();
  int64_t channels;
  if (is_3d) {
    channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  } else {
    channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  }
  int64_t input_depth = is_3d ? grad_input.size(-3) : 1;
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_depth = is_3d ? grad_output.size(-3) : 1;
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;
      const scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;
      const int64_t* indices_ptr = indices_data + c * output_depth * output_height * output_width;

      for (int64_t od = 0; od < output_depth; od++) {
        for (int64_t oh = 0; oh < output_height; oh++) {
          for (int64_t ow = 0; ow < output_width; ow++) {
            // retrieve position of max
            int64_t index = od * output_height * output_width + oh * output_width + ow;
            int64_t maxindex = indices_ptr[index];
            if (maxindex != -1) {
              // update gradient
              grad_input_ptr[maxindex] += grad_output_ptr[index];
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

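// Backward kernel for the channels-last (NHWC / NDHWC) layout. Parallelized over the
// batch dimension, so each thread scatters only into its own slice of grad_input;
// the innermost loop walks the channel dimension for every output position.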
template <typename scalar_t, bool is_3d>
void cpu_max_pool_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const Tensor& indices_) {
  int64_t ndim = grad_output_.ndimension();
  if (is_3d) {
    TORCH_CHECK(ndim == 5, "MaxPool3d backward with channels last format supports tensors with 5 dims.");
  } else {
    TORCH_CHECK(ndim == 4, "MaxPool2d backward with channels last format supports tensors with 4 dims.");
  }
  auto memory_format = is_3d ? at::MemoryFormat::ChannelsLast3d
                             : at::MemoryFormat::ChannelsLast;
  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto indices_data = indices.const_data_ptr<int64_t>();

  // MaxPool2d: NHWC
  // MaxPool3d: NDHWC
  int64_t nbatch = grad_input.size(0);
  int64_t channels = grad_input.size(1);
  int64_t input_depth = is_3d ? grad_input.size(2) : 1;
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_depth = is_3d ? grad_output.size(2) : 1;
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim N
  at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
    for (const auto n : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + n * input_depth * input_height * input_width * channels;
      const scalar_t* grad_output_ptr = grad_output_data + n * output_depth * output_height * output_width * channels;
      const int64_t* indices_ptr = indices_data + n * output_depth * output_height * output_width * channels;

      for (int64_t od = 0; od < output_depth; od++) {
        for (int64_t oh = 0; oh < output_height; oh++) {
          for (int64_t ow = 0; ow < output_width; ow++) {
            const scalar_t* gout = grad_output_ptr + (od * output_height * output_width + oh * output_width + ow) * channels;
            const int64_t* ind = indices_ptr + (od * output_height * output_width + oh * output_width + ow) * channels;
            // TODO: gcc vectorization
            for (int64_t c = 0; c < channels; c++) {
              int64_t maxindex = ind[c];
              if (maxindex != -1) {
                grad_input_ptr[maxindex * channels + c] += gout[c];
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

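// Dispatch stubs registered below: pick the contiguous or channels-last kernel from the
// suggested memory format of the input (forward) or grad_output (backward), and dispatch
// on dtype (all types plus BFloat16/Half for forward, floating types plus BFloat16/Half
// for backward).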
void max_pool2d_kernel_impl(
    const Tensor& output,
    const Tensor& indices,
    const Tensor& input,
    int kW, int kH,
    int dW, int dH,
    int padW, int padH,
    int dilationW, int dilationH) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d", [&] {
        cpu_max_pool<scalar_t, /*is 3d*/false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d_channels_last", [&] {
        cpu_max_pool_channels_last<scalar_t, false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void max_pool3d_kernel_impl(
    Tensor& output,
    Tensor& indices,
    const Tensor& input,
    int kW, int kH, int kD,
    int dW, int dH, int dD,
    int padW, int padH, int padD,
    int dilationW, int dilationH, int dilationD) {
  if (input.ndimension() == 4) {
    Tensor input_cl_check = input.unsqueeze(0);
    // align with cuda:
    // work around buggy behavior of suggest_memory_format here, where the
    // suggested format of the unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    if ((!input_cl_check.is_contiguous()) &&
         input_cl_check.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      TORCH_CHECK(output.ndimension() == 4 && indices.ndimension() == 4);
      DimVector out_sizes(output.sizes().begin(), output.sizes().end());
      out_sizes.insert(out_sizes.begin(), 1);
      output.resize_(out_sizes, at::MemoryFormat::ChannelsLast3d);
      DimVector indices_sizes(indices.sizes().begin(), indices.sizes().end());
      indices_sizes.insert(indices_sizes.begin(), 1);
      indices.resize_(indices_sizes, at::MemoryFormat::ChannelsLast3d);
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
        cpu_max_pool_channels_last<scalar_t, /*is 3d*/true>(output, indices, input_cl_check,
            {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
      });
      output.squeeze_(0);
      indices.squeeze_(0);
      return;
    }
  }
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d", [&] {
        cpu_max_pool<scalar_t, /*is 3d*/true>(output, indices, input,
            {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
        cpu_max_pool_channels_last<scalar_t, true>(output, indices, input,
            {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

void max_pool2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& indices) {
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward", [&] {
        cpu_max_pool_backward<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
        cpu_max_pool_backward_channels_last<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void max_pool3d_backward_kernel_impl(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& indices) {
  if (grad_output.ndimension() == 4) {
    Tensor grad_output_cl_check = grad_output.unsqueeze(0);
    // align with cuda:
    // work around buggy behavior of suggest_memory_format here, where the
    // suggested format of the unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    if ((!grad_output_cl_check.is_contiguous()) &&
         grad_output_cl_check.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      TORCH_CHECK(grad_input.ndimension() == 4 && indices.ndimension() == 4);
      DimVector sizes(grad_input.sizes().begin(), grad_input.sizes().end());
      sizes.insert(sizes.begin(), 1);
      grad_input.resize_(sizes, at::MemoryFormat::ChannelsLast3d);
      auto _indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
        cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output_cl_check, _indices);
      });
      grad_input.squeeze_(0);
      return;
    }
  }
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward", [&] {
        cpu_max_pool_backward<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
        cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

} // anonymous namespace

REGISTER_DISPATCH(max_pool2d_kernel, &max_pool2d_kernel_impl);
REGISTER_DISPATCH(max_pool2d_backward_kernel, &max_pool2d_backward_kernel_impl);
REGISTER_DISPATCH(max_pool3d_kernel, &max_pool3d_kernel_impl);
REGISTER_DISPATCH(max_pool3d_backward_kernel, &max_pool3d_backward_kernel_impl);

} // namespace at::native
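
// Usage sketch (illustrative only, not part of this kernel): on CPU these kernels are
// reached through the regular ATen operators, e.g.
//
//   at::Tensor input = at::randn({1, 3, 8, 8});                         // NCHW -> contiguous path
//   auto [out, ind] = at::max_pool2d_with_indices(
//       input, /*kernel_size=*/{2, 2}, /*stride=*/{2, 2},
//       /*padding=*/{0, 0}, /*dilation=*/{1, 1}, /*ceil_mode=*/false);
//
//   at::Tensor nhwc = input.contiguous(at::MemoryFormat::ChannelsLast); // channels-last path
//   auto [out_cl, ind_cl] = at::max_pool2d_with_indices(
//       nhwc, {2, 2}, {2, 2}, {0, 0}, {1, 1}, false);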