#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #else #include #include #endif namespace at::meta { TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_size) { auto ndim = input.ndimension(); TORCH_CHECK( ndim == 4 || ndim == 5, "adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: ", input.sizes()); for (const auto i : c10::irange(1, ndim)) { TORCH_CHECK( input.size(i) > 0, "adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, " "but input has sizes ", input.sizes(), " with dimension ", i, " being " "empty"); } TORCH_CHECK( output_size.size() == 3, "adaptive_max_pool3d(): internal error: output_size.size() must be 3"); int dimD = 0; int64_t sizeB = 1; int64_t sizeD = 0; if (ndim == 5) { sizeB = input.size(0); dimD++; } /* sizes */ sizeD = input.size(dimD); int64_t osizeT = output_size[0]; int64_t osizeH = output_size[1]; int64_t osizeW = output_size[2]; /* resize output */ if (ndim == 4) { set_output_raw_strided(0, {sizeD, osizeT, osizeH, osizeW}, {}, input.options()); /* indices will contain max input locations for each output point */ set_output_raw_strided(1, {sizeD, osizeT, osizeH, osizeW}, {}, input.options().dtype(kLong)); } else { set_output_raw_strided(0, {sizeB, sizeD, osizeT, osizeH, osizeW}, {}, input.options()); /* indices will contain max input locations for each output point */ set_output_raw_strided(1, {sizeB, sizeD, osizeT, osizeH, osizeW}, {}, input.options().dtype(kLong)); } } TORCH_META_FUNC(adaptive_max_pool3d_backward) (const Tensor& gradOutput, const Tensor& input, const Tensor& indices) { at::native::adaptive_pool_empty_output_check(gradOutput, "adaptive_max_pool3d_backward"); set_output_raw_strided(0, input.sizes(), {}, input.options()); } } // namespace meta namespace at::native { namespace { // #define START_IND(a,b,c) a * c / b // #define END_IND(a,b,c) (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0 // 5d tensor B x D x T x H x W template static void adaptive_max_pool3d_single_out_frame( const scalar_t *input_p, scalar_t *output_p, int64_t *ind_p, int64_t sizeD, int64_t isizeT, int64_t isizeH, int64_t isizeW, int64_t osizeT, int64_t osizeH, int64_t osizeW, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) { at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { for (const auto d : c10::irange(start, end)) { /* loop over output */ int64_t ot = 0, oh = 0, ow = 0; for(ot = 0; ot < osizeT; ot++) { int64_t istartT = start_index(ot, osizeT, isizeT); int64_t iendT = end_index(ot, osizeT, isizeT); int64_t kT = iendT - istartT; for(oh = 0; oh < osizeH; oh++) { int64_t istartH = start_index(oh, osizeH, isizeH); int64_t iendH = end_index(oh, osizeH, isizeH); int64_t kH = iendH - istartH; for(ow = 0; ow < osizeW; ow++) { int64_t istartW = start_index(ow, osizeW, isizeW); int64_t iendW = end_index(ow, osizeW, isizeW); int64_t kW = iendW - istartW; /* local pointers */ const scalar_t *ip = input_p + d*istrideD + istartT *istrideT + istartH*istrideH + istartW*istrideW; scalar_t *op = output_p + d*osizeT*osizeH*osizeW + ot*osizeH*osizeW + oh*osizeW + ow; int64_t *indp = ind_p + d*osizeT*osizeH*osizeW + ot*osizeH*osizeW + oh*osizeW + ow; /* compute local max: */ int64_t it = 0, ih = 0, iw = 0; int64_t maxindex = (it+istartT)*isizeH*isizeW + (ih+istartH)*isizeW + (iw+istartW); scalar_t maxval = -std::numeric_limits::infinity(); for(it = 0; it < kT; it++) { for(ih = 0; ih < kH; ih++) { for(iw = 0; iw < kW; iw++) { scalar_t val = *(ip + it*istrideT + ih*istrideH + iw*istrideW); if ((val > maxval) || std::isnan(val)) { maxval = val; maxindex = (it+istartT)*isizeH*isizeW + (ih+istartH)*isizeW + (iw+istartW); } } } } /* set output to local max */ *op = maxval; /* store location of max */ *indp = maxindex; } } } } }); } template static void adaptive_max_pool3d_out_frame( const scalar_t *input_data, scalar_t *output_data, int64_t *indices_data, int64_t sizeB, int64_t sizeD, int64_t isizeT, int64_t isizeH, int64_t isizeW, int64_t osizeT, int64_t osizeH, int64_t osizeW, int64_t istrideB, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) { at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { for (const auto b : c10::irange(start, end)) { adaptive_max_pool3d_single_out_frame(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW, indices_data+b*sizeD*osizeT*osizeH*osizeW, sizeD, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW); } }); } template static void adaptive_max_pool3d_backward_single_out_frame( scalar_t *gradInput_p, const scalar_t *gradOutput_p, const int64_t *ind_p, int64_t sizeD, int64_t isizeT, int64_t isizeH, int64_t isizeW, int64_t osizeT, int64_t osizeH, int64_t osizeW) { at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { for (const auto d : c10::irange(start, end)) { scalar_t *gradInput_p_d = gradInput_p + d*isizeT*isizeH*isizeW; const scalar_t *gradOutput_p_d = gradOutput_p + d*osizeT*osizeH*osizeW; const int64_t *ind_p_d = ind_p + d*osizeT*osizeH*osizeW; /* calculate max points */ int64_t ot = 0, oh = 0, ow = 0; for(ot = 0; ot < osizeT; ot++) { for(oh = 0; oh < osizeH; oh++) { for(ow = 0; ow < osizeW; ow++) { /* retrieve position of max */ int64_t maxp = ind_p_d[ot*osizeH*osizeW + oh*osizeW + ow]; /* update gradient */ gradInput_p_d[maxp] += gradOutput_p_d[ot*osizeH*osizeW + oh*osizeW + ow]; } } } } }); } template static void adaptive_max_pool3d_backward_out_frame( scalar_t *gradInput_data, const scalar_t *gradOutput_data, const int64_t *indices_data, int64_t sizeB, int64_t sizeD, int64_t isizeT, int64_t isizeH, int64_t isizeW, int64_t osizeT, int64_t osizeH, int64_t osizeW) { at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { for (const auto b : c10::irange(start, end)) { adaptive_max_pool3d_backward_single_out_frame(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW, indices_data+b*sizeD*osizeT*osizeH*osizeW, sizeD, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW); } }); } } // namespace TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu) (const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) { int dimD = 0; int dimT = 1; int dimH = 2; int dimW = 3; int64_t sizeB = 1; int64_t sizeD = 0; int64_t isizeT = 0; int64_t isizeH = 0; int64_t isizeW = 0; int64_t istrideB = 0; int64_t istrideD = 0; int64_t istrideT = 0; int64_t istrideH = 0; int64_t istrideW = 0; if (input.ndimension() == 5) { istrideB = input.stride(0); sizeB = input.size(0); dimD++; dimT++; dimH++; dimW++; } /* sizes */ sizeD = input.size(dimD); isizeT = input.size(dimT); isizeH = input.size(dimH); isizeW = input.size(dimW); /* strides */ istrideD = input.stride(dimD); istrideT = input.stride(dimT); istrideH = input.stride(dimH); istrideW = input.stride(dimW); int64_t osizeT = output_size[0]; int64_t osizeH = output_size[1]; int64_t osizeW = output_size[2]; if (input.ndimension() == 4) { AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, input.scalar_type(), "adaptive_max_pool3d_cpu", [&] { auto input_data = input.const_data_ptr(); auto output_data = output.data_ptr(); auto indices_data = indices.data_ptr(); adaptive_max_pool3d_single_out_frame( input_data, output_data, indices_data, sizeD, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW); }); } else { AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, input.scalar_type(), "adaptive_max_pool3d_cpu", [&] { auto input_data = input.const_data_ptr(); auto output_data = output.data_ptr(); auto indices_data = indices.data_ptr(); adaptive_max_pool3d_out_frame( input_data, output_data, indices_data, sizeB, sizeD, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, istrideB, istrideD, istrideT, istrideH, istrideW); }); } } TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu) (const Tensor& gradOutput, const Tensor& input, const Tensor& indices, const Tensor& gradInput) { int dimD = 0; int dimT = 1; int dimH = 2; int dimW = 3; int64_t sizeB = 1; int64_t sizeD = 0; int64_t isizeT = 0; int64_t isizeH = 0; int64_t isizeW = 0; int64_t osizeT = 0; int64_t osizeH = 0; int64_t osizeW = 0; /* get contiguous gradOutput */ auto gradOutput_ = gradOutput.contiguous(); /* resize */ gradInput.zero_(); if (input.ndimension() == 5) { sizeB = input.size(0); dimD++; dimT++; dimH++; dimW++; } /* sizes */ sizeD = input.size(dimD); isizeT = input.size(dimT); isizeH = input.size(dimH); isizeW = input.size(dimW); osizeT = gradOutput_.size(dimT); osizeH = gradOutput_.size(dimH); osizeW = gradOutput_.size(dimW); /* backprop */ if (input.ndimension() == 4) { AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, input.scalar_type(), "adaptive_max_pool3d_backward", [&] { /* get raw pointers */ scalar_t* gradInput_data = gradInput.data_ptr(); const scalar_t* gradOutput_data = gradOutput_.const_data_ptr(); const int64_t* indices_data = indices.const_data_ptr(); adaptive_max_pool3d_backward_single_out_frame( gradInput_data, gradOutput_data, indices_data, sizeD, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW); }); } else { AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, input.scalar_type(), "adaptive_max_pool3d_backward", [&] { /* get raw pointers */ scalar_t* gradInput_data = gradInput.data_ptr(); const scalar_t* gradOutput_data = gradOutput_.const_data_ptr(); const int64_t* indices_data = indices.const_data_ptr(); adaptive_max_pool3d_backward_out_frame( gradInput_data, gradOutput_data, indices_data, sizeB, sizeD, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW); }); } } } // namespace at::native