#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/cpu/MaxUnpoolKernel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/max_unpool2d_native.h>
#include <ATen/ops/max_unpool3d_native.h>
#endif
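
// CPU forward implementations for max_unpool2d and max_unpool3d. This file
// handles argument validation and output allocation; the actual scatter is
// done by the kernels declared in ATen/native/cpu/MaxUnpoolKernel.h.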

namespace at::native {

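// Writes the unpooled result into `output`: `output` is zero-filled and each
// element of `self_` is scattered to the flat per-channel position stored in
// `indices_`. Nondeterministic when `indices_` contains duplicates, since the
// final value then depends on write order.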
Tensor& max_unpooling2d_forward_out_cpu(
    const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
  at::globalContext().alertNotDeterministic("max_unpooling2d_forward_out");

  auto oheight = output_size[0];
  auto owidth = output_size[1];
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got: ", indices_.scalar_type());
  TORCH_CHECK(
      output_size.size() == 2,
      "There should be exactly two elements (height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      (self_.ndimension() == 3 || self_.ndimension() == 4),
      "Input to max_unpooling2d should be a 3d or 4d Tensor, but got a tensor with ", self_.ndimension(), " dimensions.");
  TORCH_CHECK(
      self_.sizes() == indices_.sizes(),
      "Expected shape of indices to be same as that of the input tensor (", self_.sizes(),
      ") but got indices tensor with shape: ", indices_.sizes());

  for (const auto i : c10::irange(1, self_.ndimension())) {
    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cpu(): ",
        "Expected input to have non-zero size for non-batch dimensions, but got ",
        self_.sizes(), " with dimension ", i, " being empty.");
  }

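  // Match the input's suggested memory format (contiguous or channels-last)
  // so the kernel operates on densely packed data.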
  auto memory_format = self_.suggest_memory_format();
  auto self = self_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  if (self.ndimension() == 3) {
    int64_t numChannels = self.size(0);
    output.resize_({numChannels, oheight, owidth});
  } else {
    int64_t numBatch = self.size(0);
    int64_t numChannels = self.size(1);
    output.resize_({numBatch, numChannels, oheight, owidth}, memory_format);
  }
  output.zero_();

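  // Nothing to scatter into an empty output; skip the kernel entirely.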
  if (output.numel() != 0) {
    max_unpool2d_kernel(kCPU, output, self, indices);
  }

  return output;
}

Tensor max_unpooling2d_forward_cpu(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling2d_forward_out_cpu(self, indices, output_size, output);
  return output;
}
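
// Usage sketch (hypothetical caller code, not part of this file): the op is
// normally reached through at::max_unpool2d, fed the indices produced by
// at::max_pool2d_with_indices:
//
//   auto x = at::randn({1, 3, 8, 8});
//   auto [pooled, idx] = at::max_pool2d_with_indices(x, {2, 2});
//   auto unpooled = at::max_unpool2d(pooled, idx, {8, 8});
//
// `unpooled` is zero everywhere except at the positions recorded in `idx`.

// Shared validation for the 3d unpooling entry points. `gradOutput` may be an
// undefined Tensor (as in the forward path below), in which case the
// gradOutput checks are skipped.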
static void max_unpooling3d_shape_check(
    const Tensor& input,
    const Tensor& gradOutput,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    const char *fn_name) {

  TORCH_CHECK(
      indices.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64");
  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with ", input.ndimension(), " dimensions.");
  TORCH_CHECK(
      output_size.size() == 3,
      "There should be exactly three elements (depth, height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      stride.size() == 3,
      "There should be exactly three elements (depth, height, width) in stride, but got: ", stride.size(), " elements.");
  TORCH_CHECK(
      padding.size() == 3,
      "There should be exactly three elements (depth, height, width) in padding, but got: ", padding.size(), " elements.");
  TORCH_CHECK(
      input.sizes() == indices.sizes(),
      "Expected shape of indices to be same as that of the input tensor (", input.sizes(),
      ") but got indices tensor with shape: ", indices.sizes());

  for (const auto i : c10::irange(1, input.ndimension())) {
    TORCH_CHECK(input.size(i) > 0, fn_name,
        ": Expected input to have non-zero size for non-batch dimensions, but got ",
        input.sizes(), " with dimension ", i, " being empty.");
  }

  TORCH_CHECK(
      stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
      "strides should be greater than zero, but got stride: ",
      stride);

  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

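  // Dimension indices for the 4d (C, T, H, W) case; shifted right by one
  // below when the input carries a leading batch dimension (5d).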
  int dimw = 3;
  int dimh = 2;
  int dimt = 1;
  int dimn = 0;

  if (input.ndimension() == 5) {
    dimw++;
    dimh++;
    dimt++;
    dimn++;
  }

  int64_t nslices = input.size(dimn);

  if (gradOutput.defined()) {
    if (oT != gradOutput.size(dimt) || oH != gradOutput.size(dimh) ||
        oW != gradOutput.size(dimw)) {
      AT_ERROR(
          "Inconsistent gradOutput size. oT= ", oT,
          ", oH= ", oH,
          ", oW= ", oW,
          ". gradOutput: ", gradOutput.size(dimt),
          "x", gradOutput.size(dimh),
          "x", gradOutput.size(dimw));
    }
    TORCH_CHECK(
        gradOutput.ndimension() == input.ndimension() &&
            gradOutput.size(dimn) == nslices,
        "gradOutput and input Tensors should have same number of dimensions and also the same number of channels/slices");
  }
}

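// Same contract as the 2d variant: zero-fill `output`, then scatter `self_`
// through `indices_`. `stride` and `padding` are validated for shape but do
// not affect the output size, which comes entirely from `output_size`.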
Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
  at::globalContext().alertNotDeterministic("max_unpooling3d_forward_out");

  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();

  max_unpooling3d_shape_check(
      self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cpu()");

  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

  if (self_.ndimension() == 5) {
    output.resize_({self.size(0), self.size(1), oT, oH, oW});
  } else {
    output.resize_({self.size(0), oT, oH, oW});
  }
  output.zero_();
  if (output.numel() != 0) {
    max_unpool3d_kernel(kCPU, output, self, indices);
  }

  return output;
}

Tensor max_unpooling3d_forward_cpu(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling3d_forward_out_cpu(
      self, indices, output_size, stride, padding, output);
  return output;
}
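
// Usage sketch (hypothetical caller code, not part of this file), mirroring
// the 2d flow with the 3d pooling ops:
//
//   auto x = at::randn({1, 2, 4, 8, 8});
//   auto [pooled, idx] = at::max_pool3d_with_indices(x, {2, 2, 2});
//   auto unpooled = at::max_unpool3d(pooled, idx, {4, 8, 8}, {2, 2, 2}, {0, 0, 0});

// The stubs below are bound to the vectorized CPU kernels via
// REGISTER_DISPATCH in the MaxUnpoolKernel translation unit.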
DEFINE_DISPATCH(max_unpool2d_kernel);
DEFINE_DISPATCH(max_unpool3d_kernel);

} // namespace at::native