#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/AdaptivePooling.h>
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <type_traits>
#include <ATen/OpMathType.h>
#include <ATen/native/ReduceOpsUtils.h>

namespace at::native {

namespace {

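// NaN handling: max pooling propagates NaN, i.e. a NaN anywhere in the
// pooling window wins the comparison and its index is recorded. This is why
// the comparisons below read `(val > maxval) || is_nan(val)`.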
template <typename scalar_t>
bool is_nan(scalar_t v) {
  if (std::is_integral<scalar_t>::value || std::is_same<scalar_t, unsigned char>::value) {
    return false;
  }
  return std::isnan(v);
}

template <typename scalar_t>
vec::Vectorized<scalar_t> is_nan_vec(vec::Vectorized<scalar_t> vec) {
  return vec.isnan();
}

// TODO: use is_integral/is_same to check the scalar_t and simplify the implementation;
// currently it does not work.
// Integer vector types cannot represent NaN, so these specializations simply
// return an all-false (all zeros) mask.
template <>
vec::Vectorized<unsigned char> is_nan_vec<unsigned char>(vec::Vectorized<unsigned char> vec) {
  Vectorized<unsigned char> ret(false);
  return ret;
}

template <>
vec::Vectorized<signed char> is_nan_vec<signed char>(vec::Vectorized<signed char> vec) {
  Vectorized<signed char> ret(false);
  return ret;
}

template <>
vec::Vectorized<short> is_nan_vec<short>(vec::Vectorized<short> vec) {
  Vectorized<short> ret(false);
  return ret;
}

template <>
vec::Vectorized<int> is_nan_vec<int>(vec::Vectorized<int> vec) {
  Vectorized<int> ret(false);
  return ret;
}

template <>
vec::Vectorized<int64_t> is_nan_vec<int64_t>(vec::Vectorized<int64_t> vec) {
  Vectorized<int64_t> ret(false);
  return ret;
}

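// compute_internal, full-precision overload (scalar_t == opmath_t, e.g.
// float, double and integral types): computes the running max and argmax
// over one pooling window, vectorized along the channels dimension of a
// channels-last tensor.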
template <typename scalar_t, typename opmath_t>
inline
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
compute_internal(
  const scalar_t* input_data,
  scalar_t* out_data,
  opmath_t* max_ptr,
  vec::int_same_size_t<opmath_t>* index_ptr,
  int64_t* ind,
  int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
  int64_t n,
  int64_t len,
  int64_t size,
  int64_t id0, int64_t id1,
  int64_t ih0, int64_t ih1,
  int64_t iw0, int64_t iw1,
  int64_t dilationD,
  int64_t dilationH,
  int64_t dilationW) {
  using Vec = vec::Vectorized<scalar_t>;
  using integer_t = vec::int_same_size_t<opmath_t>;
  using iVec = vec::Vectorized<integer_t>;
  // Pass I: init out lane
  iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);

  scalar_t min_value = lower_bound<scalar_t>();
  Vec out_vec = Vec(min_value);
  int64_t d1 = 0;
  for (; d1 < len; d1 += Vec::size()) {
    index0_vec.store(index_ptr + d1);
    out_vec.store(out_data + d1);
  }
  for (; d1 < size; d1++) {
    ind[d1] = id0 * input_height * input_width + ih0 * input_width + iw0;
    out_data[d1] = min_value;
  }
  // Pass II: compute local max
  for (int64_t id = id0; id < id1; id += dilationD) {
    for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
      for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
        const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;

        int64_t d2 = 0;
        for (; d2 < len; d2 += Vec::size()) {
          iVec index_vec = iVec(id * input_height * input_width + ih * input_width + iw);
          Vec val_vec = Vec::loadu(in + d2);
          iVec maxindex_vec = iVec::loadu(index_ptr + d2);
          Vec maxval_vec = Vec::loadu(out_data + d2);

          // true = all ones, false = all zeros
          Vec mask = (val_vec > maxval_vec) | is_nan_vec(val_vec);
          iVec imask = vec::cast<integer_t>(mask);
          Vec out_vec = Vec::blendv(maxval_vec, val_vec, mask);
          iVec ind_vec = iVec::blendv(maxindex_vec, index_vec, imask);

          out_vec.store(out_data + d2);
          ind_vec.store(index_ptr + d2);
        }
        for (; d2 < size; d2++) {
          int64_t index = id * input_height * input_width + ih * input_width + iw;
          scalar_t val = in[d2];
          int64_t maxindex = ind[d2];
          scalar_t maxval = out_data[d2];

          bool mask = (val > maxval) || is_nan(static_cast<double>(val));
          out_data[d2] = mask ? val : maxval;
          ind[d2] = mask ? index : maxindex;
        }
      }
    }
  }
}

// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
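// compute_internal, reduced-precision overload: the running max is
// accumulated in an opmath_t (float) side buffer (max_ptr) rather than in
// bfloat16/half, avoiding repeated low-precision rounding; the result is
// converted back to scalar_t once per output element at the end.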
template <typename scalar_t, typename opmath_t>
inline
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
compute_internal(
  const scalar_t* input_data,
  scalar_t* out_data,
  opmath_t* max_ptr,
  vec::int_same_size_t<opmath_t>* index_ptr,
  int64_t* ind,
  int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
  int64_t n,
  int64_t len,
  int64_t size,
  int64_t id0, int64_t id1,
  int64_t ih0, int64_t ih1,
  int64_t iw0, int64_t iw1,
  int64_t dilationD,
  int64_t dilationH,
  int64_t dilationW) {
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<opmath_t>;
  using iVec = vec::Vectorized<int32_t>;
  // Pass I: init out lane
  iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
  fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
  int64_t d1 = 0;
  for (; d1 < len; d1 += fVec::size()) {
    index0_vec.store(index_ptr + d1);
    out_vec.store(max_ptr + d1);
  }
  for (; d1 < size; d1++) {
    ind[d1] = id0 * input_height * input_width + ih0 * input_width + iw0;
    max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
  }
  // Pass II: compute local max
  for (int64_t id = id0; id < id1; id += dilationD) {
    for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
      for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
        const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;

        int64_t d2 = 0;
        for (; d2 < len; d2 += Vec::size()) {
          iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
          Vec val_bvec = Vec::loadu(in + d2);
          // one bfloat16/half vector widens to two float vectors
          auto [val_fvec0, val_fvec1] = convert_to_float<scalar_t>(val_bvec);

          iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
          iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
          fVec maxval_fvec0 = fVec::loadu(max_ptr + d2);
          fVec maxval_fvec1 = fVec::loadu(max_ptr + d2 + fVec::size());

          // true = all ones, false = all zeros
          fVec mask0 = (val_fvec0 > maxval_fvec0) | is_nan_vec(val_fvec0);
          fVec mask1 = (val_fvec1 > maxval_fvec1) | is_nan_vec(val_fvec1);
          iVec imask0 = vec::cast<int32_t>(mask0);
          iVec imask1 = vec::cast<int32_t>(mask1);

          fVec max_fvec0 = fVec::blendv(maxval_fvec0, val_fvec0, mask0);
          fVec max_fvec1 = fVec::blendv(maxval_fvec1, val_fvec1, mask1);
          iVec ind_vec0 = iVec::blendv(maxindex_ivec0, index_ivec, imask0);
          iVec ind_vec1 = iVec::blendv(maxindex_ivec1, index_ivec, imask1);

          max_fvec0.store(max_ptr + d2);
          max_fvec1.store(max_ptr + d2 + fVec::size());
          // out_vec.store(out + d2);
          ind_vec0.store(index_ptr + d2);
          ind_vec1.store(index_ptr + d2 + iVec::size());
        }
        for (; d2 < size; d2++) {
          int64_t index = id * input_height * input_width + ih * input_width + iw;
          opmath_t val = opmath_t(in[d2]);
          int64_t maxindex = ind[d2];
          opmath_t maxval = max_ptr[d2];

          bool mask = (val > maxval) || std::isnan(val);
          max_ptr[d2] = mask ? val : maxval;
          ind[d2] = mask ? index : maxindex;
        }
      }
    }
  }
  // Convert max values from float to bfloat16/half
  int64_t d3 = 0;
  for (; d3 < len; d3 += Vec::size()) {
    fVec max_fvec0 = fVec::loadu(max_ptr + d3);
    fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
    Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
    max_bvec.store(out_data + d3);
  }
  for (; d3 < size; d3++) {
    out_data[d3] = scalar_t(max_ptr[d3]);
  }
}

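// Kernel for contiguous (NCHW / NCDHW) max pooling: batch and channel
// dimensions are fused and parallelized over; each (n, c) plane is reduced
// with a scalar loop over the pooling window.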
template <typename scalar_t, bool is_3d>
void cpu_max_pool(
    const Tensor& output_,
    const Tensor& indices_,
    const Tensor& input_,
    IntArrayRef kWHD,
    IntArrayRef dWHD,
    IntArrayRef padWHD,
    IntArrayRef dilWHD) {
  size_t dims = is_3d ? 3 : 2;
  TORCH_CHECK(kWHD.size() == dims && dWHD.size() == dims && padWHD.size() == dims && dilWHD.size() == dims,
              "max pooling 2d/3d parameters are not matched");
  int kW = kWHD[0];
  int kH = kWHD[1];
  int dW = dWHD[0];
  int dH = dWHD[1];
  int padW = padWHD[0];
  int padH = padWHD[1];
  int dilationW = dilWHD[0];
  int dilationH = dilWHD[1];

  int kD = is_3d ? kWHD[dims - 1] : 1;
  int dD = is_3d ? dWHD[dims - 1] : 1;
  int padD = is_3d ? padWHD[dims - 1] : 0;
  int dilationD = is_3d ? dilWHD[dims - 1] : 1;

  auto input = input_.contiguous();
  auto output = output_.contiguous();
  auto indices = indices_.contiguous();

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();
  auto indices_data = indices.data_ptr<int64_t>();

  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  //
  // MaxPool2d:
  //   ndim == 3: CHW
  //   ndim == 4: NCHW
  //
  // MaxPool3d:
  //   ndim == 4: CDHW
  //   ndim == 5: NCDHW
  int64_t channels;
  if (is_3d) {
    channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
  } else {
    channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
  }
  int64_t input_depth = is_3d ? input.size(-3) : 1;
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_depth = is_3d ? output.size(-3) : 1;
  int64_t output_height = output.size(-2);
  int64_t output_width = output.size(-1);

  using opmath_t = at::opmath_type<scalar_t>;
  // parallel on dim N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (int64_t c = begin; c < end; c++) {
      const scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
      scalar_t* output_ptr = output_data + c * output_depth * output_height * output_width;
      int64_t* indices_ptr = indices_data + c * output_depth * output_height * output_width;

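      // [id0, id1) x [ih0, ih1) x [iw0, iw1) is the input window for each
      // output element; starts that fall into the padding region are advanced
      // by whole dilation steps until they land inside the input.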
      for (int64_t od = 0; od < output_depth; od++) {
        int64_t id0 = od * dD - padD;
        int64_t id1 = std::min(id0 + (kD - 1) * dilationD + 1, input_depth);
        while(id0 < 0) { id0 += dilationD; }

        for (int64_t oh = 0; oh < output_height; oh++) {
          int64_t ih0 = oh * dH - padH;
          int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
          while(ih0 < 0) { ih0 += dilationH; }

          for (int64_t ow = 0; ow < output_width; ow++) {
            int64_t iw0 = ow * dW - padW;
            int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
            while(iw0 < 0) { iw0 += dilationW; }

            // compute local max
            int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
            opmath_t maxval;
            if (std::numeric_limits<opmath_t>::has_infinity) {
              maxval = -std::numeric_limits<opmath_t>::infinity();
            } else {
              maxval = std::numeric_limits<opmath_t>::min();
            }

            for (int64_t id = id0; id < id1; id += dilationD) {
              for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
                for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
                  int64_t index = id * input_height * input_width + ih * input_width + iw;
                  opmath_t val = input_ptr[index];
                  if ((val > maxval) || is_nan(static_cast<double>(val))) {
                    maxval = val;
                    maxindex = index;
                  }
                }
              }
            }

            // set output to local max and store location of max
            int64_t i = od * output_height * output_width + oh * output_width + ow;
            output_ptr[i] = scalar_t(maxval);
            indices_ptr[i] = maxindex;
          }
        }
      }
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
  if (!indices_.is_contiguous()) {
    indices_.copy_(indices);
  }
}

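// Channels-last kernel: parallelize over output spatial positions
// (N, {D}, H, W) and vectorize the max/argmax reduction along the channel
// dimension via compute_internal. Indices are accumulated in an integer type
// of the same width as opmath_t and widened to int64_t afterwards.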
template <typename scalar_t, bool is_3d>
void cpu_max_pool_channels_last(
    const Tensor& output_,
    const Tensor& indices_,
    const Tensor& input_,
    IntArrayRef kWHD,
    IntArrayRef dWHD,
    IntArrayRef padWHD,
    IntArrayRef dilWHD) {
  size_t dims = is_3d ? 3 : 2;
  TORCH_CHECK(kWHD.size() == dims && dWHD.size() == dims && padWHD.size() == dims && dilWHD.size() == dims,
              "max pooling 2d/3d parameters are not matched");
  int64_t ndim = input_.ndimension();
  // MaxPool2d: NHWC
  // MaxPool3d: NDHWC
  if (is_3d) {
    TORCH_CHECK(ndim == 5, "max pooling 3d with channels last format supports tensors with 5 dims");
  } else {
    TORCH_CHECK(ndim == 4, "max pooling 2d with channels last format supports tensors with 4 dims");
  }

  int kW = kWHD[0];
  int kH = kWHD[1];
  int dW = dWHD[0];
  int dH = dWHD[1];
  int padW = padWHD[0];
  int padH = padWHD[1];
  int dilationW = dilWHD[0];
  int dilationH = dilWHD[1];

  int kD = is_3d ? kWHD[dims - 1] : 1;
  int dD = is_3d ? dWHD[dims - 1] : 1;
  int padD = is_3d ? padWHD[dims - 1] : 0;
  int dilationD = is_3d ? dilWHD[dims - 1] : 1;

  auto memory_format = is_3d ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();
  auto indices_data = indices.data_ptr<int64_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_depth = is_3d ? input.size(-3) : 1;
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_depth = is_3d ? output.size(-3) : 1;
  int64_t output_height = output.size(-2);
  int64_t output_width = output.size(-1);

  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<scalar_t>;
  using integer_t = vec::int_same_size_t<opmath_t>;
  // for the convenience of vectorization, use an integer of the same size as
  //   opmath_t, e.g. int32_t for float, int64_t for double;
  // need to make sure the flattened spatial index does not overflow
  TORCH_CHECK(input_depth * input_height * input_width <= std::numeric_limits<integer_t>::max());

  // parallel on dim N, {D}, H, W
  at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t od = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

    int64_t size = channels;
    int64_t len = size - (size % Vec::size());
    // temp buffer holding index with integer_t
    auto index_buffer = std::make_unique<integer_t []>(len);
    integer_t * index_ptr = index_buffer.get();
    // temp buffer holding max value with opmath_t
    std::unique_ptr<opmath_t []> max_arr;
    opmath_t* max_ptr = nullptr;
    if (!std::is_same<scalar_t, opmath_t>::value) {
      max_arr = std::make_unique<opmath_t[]>(size);
      max_ptr = max_arr.get();
    }

    for (int64_t i = begin; i < end; i++) {
      int64_t id0 = od * dD - padD;
      int64_t ih0 = oh * dH - padH;
      int64_t iw0 = ow * dW - padW;
      int64_t id1 = std::min(id0 + (kD - 1) * dilationD + 1, input_depth);
      int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
      int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
      while(id0 < 0) { id0 += dilationD; }
      while(ih0 < 0) { ih0 += dilationH; }
      while(iw0 < 0) { iw0 += dilationW; }

      scalar_t* out = output_data + i * channels;
      int64_t* ind = indices_data + i * channels;

      compute_internal(input_data, out, max_ptr, index_ptr, ind, input_depth, input_height, input_width, channels,
                        n, len, size, id0, id1, ih0, ih1, iw0, iw1,
                        dilationD, dilationH, dilationW);

      // convert index data type from integer_t to int64_t
      vec::convert<integer_t, int64_t>(index_buffer.get(), ind, len);

      // move on to next output index
      data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
    }
  });

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
  if (!indices_.is_contiguous(memory_format)) {
    indices_.copy_(indices);
  }
}

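// Backward for contiguous layout: each output element routes its gradient to
// the input location recorded in `indices`; an index of -1 (no recorded max)
// contributes nothing.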
template <typename scalar_t, bool is_3d>
void cpu_max_pool_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const Tensor& indices_) {
  auto grad_output = grad_output_.contiguous();
  auto indices = indices_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto indices_data = indices.const_data_ptr<int64_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  // treat batch size and channels as one dimension
  //
  // MaxPool2d:
  //   ndim == 3: CHW
  //   ndim == 4: NCHW
  //
  // MaxPool3d:
  //   ndim == 4: CDHW
  //   ndim == 5: NCDHW
  int64_t ndim = grad_output.ndimension();
  int64_t channels;
  if (is_3d) {
    channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  } else {
    channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  }
  int64_t input_depth = is_3d ? grad_input.size(-3) : 1;
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_depth = is_3d ? grad_output.size(-3) : 1;
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;
      const scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;
      const int64_t * indices_ptr = indices_data + c * output_depth * output_height * output_width;

      for (int64_t od = 0; od < output_depth; od++) {
        for (int64_t oh = 0; oh < output_height; oh++) {
          for (int64_t ow = 0; ow < output_width; ow++) {
            // retrieve position of max
            int64_t index = od * output_height * output_width + oh * output_width + ow;
            int64_t maxindex = indices_ptr[index];
            if (maxindex != -1) {
              // update gradient
              grad_input_ptr[maxindex] += grad_output_ptr[index];
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

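// Channels-last backward parallelizes over the batch dimension only: within
// one image, overlapping pooling windows can scatter into the same input
// element, so splitting the spatial loops across threads would race on the
// `+=` accumulation.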
template <typename scalar_t, bool is_3d>
void cpu_max_pool_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const Tensor& indices_) {
  int64_t ndim = grad_output_.ndimension();
  if (is_3d) {
    TORCH_CHECK(ndim == 5, "MaxPool3d backward with channels last format supports tensors with 5 dims.");
  } else {
    TORCH_CHECK(ndim == 4, "MaxPool2d backward with channels last format supports tensors with 4 dims.");
  }
  auto memory_format = is_3d ? at::MemoryFormat::ChannelsLast3d
                             : at::MemoryFormat::ChannelsLast;
  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto indices_data = indices.const_data_ptr<int64_t>();

  // MaxPool2d: NHWC
  // MaxPool3d: NDHWC
  int64_t nbatch = grad_input.size(0);
  int64_t channels = grad_input.size(1);
  int64_t input_depth = is_3d ? grad_input.size(2) : 1;
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_depth = is_3d ? grad_output.size(2) : 1;
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim N
  at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
    for (const auto n : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + n * input_depth * input_height * input_width * channels;
      const scalar_t* grad_output_ptr = grad_output_data + n * output_depth * output_height * output_width * channels;
      const int64_t* indices_ptr = indices_data + n * output_depth * output_height * output_width * channels;

      for (int64_t od = 0; od < output_depth; od++) {
        for (int64_t oh = 0; oh < output_height; oh++) {
          for (int64_t ow = 0; ow < output_width; ow++) {
            const scalar_t* gout = grad_output_ptr + (od * output_height * output_width + oh * output_width + ow) * channels;
            const int64_t* ind = indices_ptr + (od * output_height * output_width + oh * output_width + ow) * channels;
            // TODO: gcc vectorization
            for (int64_t c = 0; c < channels; c++) {
              int64_t maxindex = ind[c];
              if (maxindex != -1) {
                grad_input_ptr[maxindex * channels + c] += gout[c];
              }
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

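// Dispatch entry points: pick the contiguous or channels-last kernel based on
// the tensor's suggested memory format.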
void max_pool2d_kernel_impl(
    const Tensor& output,
    const Tensor& indices,
    const Tensor& input,
    int kW, int kH,
    int dW, int dH,
    int padW, int padH,
    int dilationW, int dilationH) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d", [&] {
        cpu_max_pool<scalar_t, /*is 3d*/false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d_channels_last", [&] {
        cpu_max_pool_channels_last<scalar_t, false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void max_pool3d_kernel_impl(
    Tensor& output,
    Tensor& indices,
    const Tensor& input,
    int kW, int kH, int kD,
    int dW, int dH, int dD,
    int padW, int padH, int padD,
    int dilationW, int dilationH, int dilationD) {
  if (input.ndimension() == 4) {
    Tensor input_cl_check = input.unsqueeze(0);
    // align with cuda:
    // work around buggy behavior of suggest_memory_format here, where the
    // suggested format of the unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    if ((!input_cl_check.is_contiguous()) &&
                     input_cl_check.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      TORCH_CHECK(output.ndimension() == 4 && indices.ndimension() == 4);
      DimVector out_sizes(output.sizes().begin(), output.sizes().end());
      out_sizes.insert(out_sizes.begin(), 1);
      output.resize_(out_sizes, at::MemoryFormat::ChannelsLast3d);
      DimVector indices_sizes(indices.sizes().begin(), indices.sizes().end());
      indices_sizes.insert(indices_sizes.begin(), 1);
      indices.resize_(indices_sizes, at::MemoryFormat::ChannelsLast3d);
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
        cpu_max_pool_channels_last<scalar_t, /*is 3d*/true>(output, indices, input_cl_check,
          {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
      });
      output.squeeze_(0);
      indices.squeeze_(0);
      return;
    }
  }
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d", [&] {
        cpu_max_pool<scalar_t, /*is 3d*/true>(output, indices, input,
            {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
        cpu_max_pool_channels_last<scalar_t, true>(output, indices, input,
          {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

void max_pool2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& indices) {
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward", [&] {
        cpu_max_pool_backward<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
        cpu_max_pool_backward_channels_last<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void max_pool3d_backward_kernel_impl(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& indices) {
  if (grad_output.ndimension() == 4) {
    Tensor grad_output_cl_check = grad_output.unsqueeze(0);
    // align with cuda:
    // work around buggy behavior of suggest_memory_format here, where the
    // suggested format of the unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    if ((!grad_output_cl_check.is_contiguous()) &&
                     grad_output_cl_check.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      TORCH_CHECK(grad_input.ndimension() == 4 && indices.ndimension() == 4);
      DimVector sizes(grad_input.sizes().begin(), grad_input.sizes().end());
      sizes.insert(sizes.begin(), 1);
      grad_input.resize_(sizes, at::MemoryFormat::ChannelsLast3d);
      auto _indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
        cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output_cl_check, _indices);
      });
      grad_input.squeeze_(0);
      return;
    }
  }
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward", [&] {
        cpu_max_pool_backward<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
        cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

} // anonymous namespace

REGISTER_DISPATCH(max_pool2d_kernel, &max_pool2d_kernel_impl);
REGISTER_DISPATCH(max_pool2d_backward_kernel, &max_pool2d_backward_kernel_impl);
REGISTER_DISPATCH(max_pool3d_kernel, &max_pool3d_kernel_impl);
REGISTER_DISPATCH(max_pool3d_backward_kernel, &max_pool3d_backward_kernel_impl);
} // at::native