#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/cpu/MaxUnpoolKernel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/max_unpool2d_native.h>
#include <ATen/ops/max_unpool3d_native.h>
#endif

namespace at::native {
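
// Max unpooling is the partial inverse of max pooling: every input value is
// scattered into a zero-initialized output at the flat offset recorded in
// `indices` for its channel plane. A tiny worked example (values invented for
// illustration): with output_size (2, 2), an input plane [[5.]] and indices
// [[3]] yield [[0., 0.], [0., 5.]], since flat offset 3 is position (1, 1).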
Tensor& max_unpooling2d_forward_out_cpu(
    const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
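  // (two input values scattered to the same output element race; which write
  // wins is unspecified, so results can vary from run to run)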
  at::globalContext().alertNotDeterministic("max_unpooling2d_forward_out");

  auto oheight = output_size[0];
  auto owidth = output_size[1];
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got: ", indices_.scalar_type());
  TORCH_CHECK(
      output_size.size() == 2,
      "There should be exactly two elements (height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      (self_.ndimension() == 3 || self_.ndimension() == 4),
      "Input to max_unpooling2d should be a 3d or 4d Tensor, but got a tensor with ", self_.ndimension(), " dimensions.");
  TORCH_CHECK(
      self_.sizes() == indices_.sizes(),
      "Expected shape of indices to be the same as that of the input tensor (", self_.sizes(),
      ") but got indices tensor with shape: ", indices_.sizes());

  for (const auto i : c10::irange(1, self_.ndimension())) {
    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cpu(): ",
                "Expected input to have non-zero size for non-batch dimensions, but got ",
                self_.sizes(), " with dimension ", i, " being empty.");
  }

  auto memory_format = self_.suggest_memory_format();
  auto self = self_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  if (self.ndimension() == 3) {
    int64_t numChannels = self.size(0);
    output.resize_({numChannels, oheight, owidth});
  } else {
    int64_t numBatch = self.size(0);
    int64_t numChannels = self.size(1);
    output.resize_({numBatch, numChannels, oheight, owidth}, memory_format);
  }
  output.zero_();

  if (output.numel() != 0) {
    max_unpool2d_kernel(kCPU, output, self, indices);
  }

  return output;
}
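
// Usage sketch (not part of this file; assumes a standard ATen build where
// at::max_pool2d_with_indices and at::max_unpool2d are available):
//
//   at::Tensor x = at::randn({1, 3, 4, 4});
//   auto [pooled, indices] = at::max_pool2d_with_indices(x, /*kernel_size=*/{2, 2});
//   at::Tensor unpooled = at::max_unpool2d(pooled, indices, /*output_size=*/{4, 4});
//   // unpooled has shape {1, 3, 4, 4}; positions that were not the max are zero.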

Tensor max_unpooling2d_forward_cpu(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling2d_forward_out_cpu(self, indices, output_size, output);
  return output;
}

static void max_unpooling3d_shape_check(
    const Tensor& input,
    const Tensor& gradOutput,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    const char* fn_name) {

  TORCH_CHECK(
      indices.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64");
  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with ", input.ndimension(), " dimensions.");
  TORCH_CHECK(
      output_size.size() == 3,
      "There should be exactly three elements (depth, height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      stride.size() == 3,
      "There should be exactly three elements (depth, height, width) in stride, but got ", stride.size(), " elements.");
  TORCH_CHECK(
      padding.size() == 3,
      "There should be exactly three elements (depth, height, width) in padding, but got ", padding.size(), " elements.");
  TORCH_CHECK(
      input.sizes() == indices.sizes(),
      "Expected shape of indices to be the same as that of the input tensor (", input.sizes(),
      ") but got indices tensor with shape: ", indices.sizes());

  for (const auto i : c10::irange(1, input.ndimension())) {
    TORCH_CHECK(input.size(i) > 0, fn_name,
                ": Expected input to have non-zero size for non-batch dimensions, but got ",
                input.sizes(), " with dimension ", i, " being empty.");
  }

  TORCH_CHECK(
      stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
      "strides should be greater than zero, but got stride: ",
      stride);

  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

  int dimw = 3;
  int dimh = 2;
  int dimt = 1;
  int dimn = 0;
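  // For a 4d (C, T, H, W) input, dimn indexes channels and dimt/dimh/dimw the
  // T/H/W dims. A 5d input carries a leading batch dimension (N, C, T, H, W),
  // so every index shifts right by one below; dimn then still selects channels.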
  if (input.ndimension() == 5) {
    dimw++;
    dimh++;
    dimt++;
    dimn++;
  }

  int64_t nslices = input.size(dimn);

  if (gradOutput.defined()) {
    TORCH_CHECK(
        oT == gradOutput.size(dimt) && oH == gradOutput.size(dimh) &&
            oW == gradOutput.size(dimw),
        "Inconsistent gradOutput size. oT= ", oT, ", oH= ", oH, ", oW= ", oW,
        ". gradOutput: ", gradOutput.size(dimt), "x", gradOutput.size(dimh),
        "x", gradOutput.size(dimw));
    TORCH_CHECK(
        gradOutput.ndimension() == input.ndimension() &&
            gradOutput.size(dimn) == nslices,
        "gradOutput and input Tensors should have the same number of dimensions and the same number of channels/slices");
  }
}

Tensor& max_unpooling3d_forward_out_cpu(
    const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
  at::globalContext().alertNotDeterministic("max_unpooling3d_forward_out");

  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();

  max_unpooling3d_shape_check(
      self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cpu()");

  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

  if (self_.ndimension() == 5) {
    output.resize_({self.size(0), self.size(1), oT, oH, oW});
  } else {
    output.resize_({self.size(0), oT, oH, oW});
  }
  output.zero_();
  if (output.numel() != 0) {
    max_unpool3d_kernel(kCPU, output, self, indices);
  }

  return output;
}
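
// Usage sketch (also outside this file; assumes the standard
// at::max_pool3d_with_indices and at::max_unpool3d entry points):
//
//   at::Tensor x = at::randn({1, 2, 4, 4, 4});
//   auto [pooled, indices] = at::max_pool3d_with_indices(x, /*kernel_size=*/{2, 2, 2});
//   at::Tensor unpooled = at::max_unpool3d(
//       pooled, indices, /*output_size=*/{4, 4, 4}, /*stride=*/{2, 2, 2}, /*padding=*/{0, 0, 0});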

Tensor max_unpooling3d_forward_cpu(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling3d_forward_out_cpu(
      self, indices, output_size, stride, padding, output);
  return output;
}
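
// DEFINE_DISPATCH below instantiates the dispatch-stub storage for the kernels
// declared in ATen/native/cpu/MaxUnpoolKernel.h; the CPU implementations are
// registered separately (in aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp in a
// typical PyTorch tree).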

DEFINE_DISPATCH(max_unpool2d_kernel);
DEFINE_DISPATCH(max_unpool3d_kernel);

} // namespace at::native