#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#include <ATen/native/AdaptivePooling.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_adaptive_avg_pool3d.h>
#include <ATen/ops/_adaptive_avg_pool3d_backward_native.h>
#include <ATen/ops/_adaptive_avg_pool3d_native.h>
#include <ATen/ops/adaptive_avg_pool3d_backward_native.h>
#include <ATen/ops/adaptive_avg_pool3d_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

namespace {
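// Forward kernel for one frame (no batch dimension): for each output cell
// (ot, oh, ow), average the input window given by start_index/end_index
// along T, H and W. The input is addressed through explicit strides, the
// output is written contiguously, and work is parallelized over the
// feature dimension D.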
template <typename scalar_t>
static void adaptive_avg_pool3d_out_frame(
    const scalar_t* input_p,
    scalar_t* output_p,
    int64_t sizeD,
    int64_t isizeT,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeT,
    int64_t osizeH,
    int64_t osizeW,
    int64_t istrideD,
    int64_t istrideT,
    int64_t istrideH,
    int64_t istrideW) {
  at::parallel_for(0, sizeD, 1, [&](int64_t start, int64_t end) {
    for (const auto d : c10::irange(start, end)) {
      /* loop over output */
      for (const auto ot : c10::irange(osizeT)) {
        auto istartT = start_index(ot, osizeT, isizeT);
        auto iendT = end_index(ot, osizeT, isizeT);
        auto kT = iendT - istartT;

        for (const auto oh : c10::irange(osizeH)) {
          auto istartH = start_index(oh, osizeH, isizeH);
          auto iendH = end_index(oh, osizeH, isizeH);
          auto kH = iendH - istartH;

          for (const auto ow : c10::irange(osizeW)) {
            auto istartW = start_index(ow, osizeW, isizeW);
            auto iendW = end_index(ow, osizeW, isizeW);
            auto kW = iendW - istartW;

            /* local pointers */
            const scalar_t* ip = input_p + d * istrideD + istartT * istrideT +
                istartH * istrideH + istartW * istrideW;
            scalar_t* op = output_p + d * osizeT * osizeH * osizeW +
                ot * osizeH * osizeW + oh * osizeW + ow;

            /* compute local average: */
            scalar_t sum = 0;
            for (const auto it : c10::irange(kT)) {
              for (const auto ih : c10::irange(kH)) {
                for (const auto iw : c10::irange(kW)) {
                  scalar_t val =
                      *(ip + it * istrideT + ih * istrideH + iw * istrideW);
                  sum += val;
                }
              }
            }

            /* set output to local average */
            *op = sum / kT / kH / kW;
          }
        }
      }
    }
  });
}
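// Shared CPU implementation behind the out= and functional variants:
// validates shapes and dtypes, resizes `output`, and dispatches on the
// scalar type. A 4D (D, T, H, W) input is pooled as a single frame; a 5D
// (N, D, T, H, W) input is additionally parallelized over the batch
// dimension.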
void adaptive_avg_pool3d_out_cpu_template(
    Tensor& output,
    Tensor const& input,
    IntArrayRef output_size) {
  TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");

  for (const auto i : c10::irange(1, input.ndimension())) {
    TORCH_CHECK(
        input.size(i) > 0,
        "adaptive_avg_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
        "but input has sizes ",
        input.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }

  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got ",
      input.sizes());
  TORCH_CHECK(input.dtype() == output.dtype(),
      "expected dtype ", input.dtype(), " for `output` but got dtype ", output.dtype());

  /* sizes */
  int64_t sizeD = input.size(-4);
  int64_t isizeT = input.size(-3);
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  /* strides */
  int64_t istrideD = input.stride(-4);
  int64_t istrideT = input.stride(-3);
  int64_t istrideH = input.stride(-2);
  int64_t istrideW = input.stride(-1);
  /* output sizes */
  auto osizeT = output_size[0];
  auto osizeH = output_size[1];
  auto osizeW = output_size[2];

  if (input.ndimension() == 4) {
    output.resize_({sizeD, osizeT, osizeH, osizeW});

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
          auto input_data = input.const_data_ptr<scalar_t>();
          auto output_data = output.data_ptr<scalar_t>();
          adaptive_avg_pool3d_out_frame<scalar_t>(
              input_data,
              output_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW,
              istrideD,
              istrideT,
              istrideH,
              istrideW);
        });
  } else {
    output.resize_({input.size(-5), sizeD, osizeT, osizeH, osizeW});
    int64_t n = input.size(0);

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
          auto input_data = input.const_data_ptr<scalar_t>();
          auto output_data = output.data_ptr<scalar_t>();
          at::parallel_for(0, n, 1, [&](int64_t start, int64_t end) {
            for (const auto b : c10::irange(start, end)) {
              adaptive_avg_pool3d_out_frame<scalar_t>(
                  input_data + b * input.stride(0),
                  output_data + b * sizeD * osizeT * osizeH * osizeW,
                  sizeD,
                  isizeT,
                  isizeH,
                  isizeW,
                  osizeT,
                  osizeH,
                  osizeW,
                  istrideD,
                  istrideT,
                  istrideH,
                  istrideW);
            }
          });
        });
  }
}
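// Backward kernel for one frame: each gradOutput element is divided by its
// pooling window size (kT * kH * kW) and accumulated into every gradInput
// position of that window. Both tensors are addressed as contiguous, and
// gradInput is expected to be zero-initialized by the caller; work is
// parallelized over the feature dimension D.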
template <typename scalar_t>
static void adaptive_avg_pool3d_backward_out_frame(
    scalar_t* gradInput_p,
    const scalar_t* gradOutput_p,
    int64_t sizeD,
    int64_t isizeT,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeT,
    int64_t osizeH,
    int64_t osizeW) {
  at::parallel_for(0, sizeD, 1, [&](int64_t start, int64_t end) {
    for (const auto d : c10::irange(start, end)) {
      scalar_t* gradInput_p_d = gradInput_p + d * isizeT * isizeW * isizeH;
      const scalar_t* gradOutput_p_d = gradOutput_p + d * osizeT * osizeW * osizeH;

      /* calculate average */
      for (const auto ot : c10::irange(osizeT)) {
        auto istartT = start_index(ot, osizeT, isizeT);
        auto iendT = end_index(ot, osizeT, isizeT);
        auto kT = iendT - istartT;

        for (const auto oh : c10::irange(osizeH)) {
          auto istartH = start_index(oh, osizeH, isizeH);
          auto iendH = end_index(oh, osizeH, isizeH);
          auto kH = iendH - istartH;

          for (const auto ow : c10::irange(osizeW)) {
            auto istartW = start_index(ow, osizeW, isizeW);
            auto iendW = end_index(ow, osizeW, isizeW);
            auto kW = iendW - istartW;

            scalar_t grad_delta =
                gradOutput_p_d[ot * osizeH * osizeW + oh * osizeW + ow] / kT /
                kH / kW;

            for (const auto it : c10::irange(istartT, iendT)) {
              for (const auto ih : c10::irange(istartH, iendH)) {
                for (const auto iw : c10::irange(istartW, iendW)) {
                  /* update gradient */
                  gradInput_p_d[it * isizeH * isizeW + ih * isizeW + iw] +=
                      grad_delta;
                }
              }
            }
          }
        }
      }
    }
  });
}
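// Shared CPU implementation of the backward pass: makes gradOutput
// contiguous, applies the empty-output check, and dispatches on the scalar
// type. `gradInput` must already be resized and zeroed by the caller. A 5D
// input is additionally parallelized over the batch dimension.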
Tensor& adaptive_avg_pool3d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  /* get contiguous gradOutput */
  auto gradOutput = gradOutput_.contiguous();

  adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool3d_backward");

  /* sizes */
  int64_t sizeD = input.size(-4);
  int64_t isizeT = input.size(-3);
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeT = gradOutput.size(-3);
  int64_t osizeH = gradOutput.size(-2);
  int64_t osizeW = gradOutput.size(-1);

  /* backprop */
  if (input.ndimension() == 4) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
          /* get raw pointers */
          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

          adaptive_avg_pool3d_backward_out_frame<scalar_t>(
              gradInput_data,
              gradOutput_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW);
        });
  } else {
    int64_t n = input.size(0);

    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
          /* get raw pointers */
          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
          at::parallel_for(0, n, 1, [&](int64_t start, int64_t end) {
            for (const auto b : c10::irange(start, end)) {
              adaptive_avg_pool3d_backward_out_frame<scalar_t>(
                  gradInput_data + b * sizeD * isizeT * isizeH * isizeW,
                  gradOutput_data + b * sizeD * osizeT * osizeH * osizeW,
                  sizeD,
                  isizeT,
                  isizeH,
                  isizeW,
                  osizeT,
                  osizeH,
                  osizeW);
            }
          });
        });
  }
  return gradInput;
}

} // namespace
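// Public CPU entry points (the helpers above live in the anonymous namespace).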
Tensor& adaptive_avg_pool3d_out_cpu(const Tensor& input,
                                    IntArrayRef output_size,
                                    Tensor& output) {
  adaptive_avg_pool3d_out_cpu_template(output, input, output_size);
  return output;
}

Tensor adaptive_avg_pool3d_cpu(Tensor const& input, IntArrayRef output_size) {
  auto output = at::empty({0}, input.options());
  adaptive_avg_pool3d_out_cpu_template(output, input, output_size);
  return output;
}
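// Size-polymorphic (SymInt) entry point. An output_size of {1, 1, 1} is
// lowered to a plain mean over the last three dimensions (restrided for
// channels-last-3d inputs); anything else falls through to
// _adaptive_avg_pool3d.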
Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_size) {
  TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");
  TORCH_CHECK(
      (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0),
      "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ",
      "but received {", output_size[0], ", ", output_size[1], ", ", output_size[2], "}");

  if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1) {
    // in this case, adaptive pooling is just computing mean over the
    // (t, h, w) dimensions, which can be done more efficiently
    Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true);
    if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d) {
      // assert ndim == 5, since ndim == 4 doesn't give channels_last
      const auto n = input.sym_size(0);
      const auto c = input.sym_size(1);
      out.as_strided__symint({n, c, 1, 1, 1}, {c, 1, c, c, c});
    }
    return out;
  } else {
    return _adaptive_avg_pool3d_symint(input, output_size);
  }
}

Tensor& adaptive_avg_pool3d_backward_out_cpu(const Tensor& gradOutput_,
                                             const Tensor& input,
                                             Tensor& gradInput) {
  gradInput.resize_as_(input).zero_();
  adaptive_avg_pool3d_backward_out_cpu_template(gradInput, gradOutput_, input);
  return gradInput;
}

Tensor adaptive_avg_pool3d_backward_cpu(const Tensor& gradOutput_,
                                        const Tensor& input) {
  auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  adaptive_avg_pool3d_backward_out_cpu_template(gradInput, gradOutput_, input);
  return gradInput;
}

} // namespace at::native