1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/AdaptivePooling.h>
4 #include <c10/util/irange.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/adaptive_max_pool2d_backward_native.h>
11 #include <ATen/ops/adaptive_max_pool2d_native.h>
12 #endif
13
14 namespace at::meta {
TORCH_META_FUNC(adaptive_max_pool2d)15 TORCH_META_FUNC(adaptive_max_pool2d) (const Tensor& input, IntArrayRef output_size) {
16 int ndim = input.ndimension();
17 TORCH_CHECK(ndim == 3 || ndim == 4,
18 "adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: ",
19 input.sizes());
20 for (const auto i : c10::irange(1, ndim)) {
21 TORCH_CHECK(input.size(i) > 0,
22 "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
23 "but input has sizes ", input.sizes(), " with dimension ", i,
24 " being empty");
25 }
26
27 TORCH_CHECK(output_size.size() == 2,
28 "adaptive_max_pool2d(): internal error: output_size.size() must be 2");
29
30 int dimH = 1;
31 int64_t sizeB = 1;
32 int64_t sizeD = 0;
33
34 if (input.ndimension() == 4) {
35 sizeB = input.size(0);
36 dimH++;
37 }
38
39 sizeD = input.size(dimH - 1);
40
41 int64_t osizeH = output_size[0];
42 int64_t osizeW = output_size[1];
43
44 /* resize output */
45 if (input.ndimension() == 3) {
46 set_output_raw_strided(0, {sizeD, osizeH, osizeW}, {}, input.options());
47 /* indices will contain i,j locations for each output point */
48 set_output_raw_strided(1, {sizeD, osizeH, osizeW}, {}, input.options().dtype(kLong));
49 } else {
50 set_output_raw_strided(0, {sizeB, sizeD, osizeH, osizeW}, {}, input.options().memory_format(input.suggest_memory_format()));
51 /* indices will contain i,j locations for each output point */
52 set_output_raw_strided(1, {sizeB, sizeD, osizeH, osizeW}, {}, input.options().memory_format(input.suggest_memory_format()).dtype(kLong));
53 }
54 }
55
TORCH_META_FUNC(adaptive_max_pool2d_backward)56 TORCH_META_FUNC(adaptive_max_pool2d_backward)
57 (const Tensor& grad_output, const Tensor& input, const Tensor& indices) {
58 int64_t ndim = grad_output.ndimension();
59 TORCH_CHECK(ndim == 3 || ndim == 4,
60 "adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: ", grad_output.sizes());
61
62 at::native::adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward");
63
64 TORCH_CHECK(input.dtype() == grad_output.dtype(),
65 "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype());
66
67 set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(input.suggest_memory_format()));
68 }
69 } // namespace at::meta
70
71 namespace at::native {
72
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_cpu)73 TORCH_IMPL_FUNC(adaptive_max_pool2d_out_cpu)
74 (const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
75 adaptive_max_pool2d_kernel(kCPU, output, indices, input, output_size);
76 }
77
TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cpu)78 TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cpu)
79 (const Tensor& grad_output, const Tensor& input, const Tensor& indices, const Tensor& grad_input) {
80 grad_input.zero_();
81 adaptive_max_pool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
82 }
83
84 DEFINE_DISPATCH(adaptive_max_pool2d_kernel);
85 DEFINE_DISPATCH(adaptive_max_pool2d_backward_kernel);
86
87 } // namespace at::native
88