1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/NamedTensorUtils.h>
5 #include <ATen/native/Pool.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/empty.h>
12 #include <ATen/ops/max_pool3d_with_indices_backward_native.h>
13 #include <ATen/ops/max_pool3d_with_indices_native.h>
14 #endif
15
16 namespace at::native {
17
18 namespace {
19
20
max_pool3d_with_indices_out_cpu_template(Tensor & output,Tensor & indices,const Tensor & input,IntArrayRef kernel_size,IntArrayRef stride,IntArrayRef padding,IntArrayRef dilation,bool ceil_mode)21 void max_pool3d_with_indices_out_cpu_template(
22 Tensor& output,
23 Tensor& indices,
24 const Tensor& input,
25 IntArrayRef kernel_size,
26 IntArrayRef stride,
27 IntArrayRef padding,
28 IntArrayRef dilation,
29 bool ceil_mode)
30 {
31 // #20866, #22032: Guarantee this for the official C++ API?
32 TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
33 "max_pool3d: kernel_size must either be a single int, or a tuple of three ints")
34 const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
35 const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
36 const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);
37
38 TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
39 "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints")
40 const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
41 const int dH = stride.empty() ? kH :
42 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
43 const int dW = stride.empty() ? kW :
44 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);
45
46 TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
47 "max_pool3d: padding must either be a single int, or a tuple of three ints");
48 const int pT = safe_downcast<int, int64_t>(padding[0]);
49 const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
50 const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);
51
52 TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
53 "max_pool3d: dilation must be either a single int, or a tuple of three ints");
54 const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
55 const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
56 const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);
57
58 const auto memory_format = input.suggest_memory_format();
59 if (memory_format == at::MemoryFormat::ChannelsLast3d) {
60 TORCH_CHECK(input.ndimension() == 5,
61 "non-empty 5D (batch mode) tensor expected for input with channels_last_3d layout");
62 } else if (memory_format == at::MemoryFormat::Contiguous) {
63 TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
64 "non-empty 4D or 5D (batch mode) tensor expected for input");
65 } else {
66 TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous");
67 }
68
69 const int64_t nslices = input.size(-4);
70 const int64_t itime = input.size(-3);
71 const int64_t iheight = input.size(-2);
72 const int64_t iwidth = input.size(-1);
73
74 const int64_t otime = pooling_output_shape<int64_t>(itime, kT, pT, dT, dilationT, ceil_mode);
75 const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
76 const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);
77
78 pool3d_shape_check(
79 input,
80 nslices,
81 kT, kH, kW,
82 dT, dH, dW,
83 pT, pH, pW,
84 dilationT, dilationH, dilationW,
85 itime, iheight, iwidth,
86 otime, oheight, owidth,
87 "max_pool3d_with_indices_out_cpu_template()");
88
89
90 if (input.dim() == 4) { /* non-batch mode */
91 /* resize output */
92 output.resize_({nslices, otime, oheight, owidth});
93 /* indices will contain ti,i,j locations for each output point */
94 indices.resize_({nslices, otime, oheight, owidth});
95 }
96 else { /* batch mode */
97 const int64_t nbatch = input.size(0);
98
99 /* resize output */
100 output.resize_({nbatch, nslices, otime, oheight, owidth}, memory_format);
101 /* indices will contain ti,i,j locations for each output point */
102 indices.resize_({nbatch, nslices, otime, oheight, owidth}, memory_format);
103 }
104 max_pool3d_kernel(
105 kCPU, output, indices, input,
106 kW, kH, kT,
107 dW, dH, dT,
108 pW, pH, pT,
109 dilationW, dilationH, dilationT);
110 }
111
max_pool3d_with_indices_backward_out_cpu_template(Tensor & gradInput,const Tensor & gradOutput,const Tensor & input,const Tensor & indices,IntArrayRef kernel_size,IntArrayRef stride,IntArrayRef padding,IntArrayRef dilation,bool ceil_mode)112 Tensor& max_pool3d_with_indices_backward_out_cpu_template(
113 Tensor& gradInput,
114 const Tensor& gradOutput,
115 const Tensor& input,
116 const Tensor& indices,
117 IntArrayRef kernel_size,
118 IntArrayRef stride,
119 IntArrayRef padding,
120 IntArrayRef dilation,
121 bool ceil_mode)
122 {
123 // #20866, #22032: Guarantee this for the official C++ API?
124 TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
125 "max_pool3d: kernel_size must either be a single int, or a tuple of three ints")
126 const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
127 const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
128 const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);
129
130 TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
131 "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints")
132 const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
133 const int dH = stride.empty() ? kH :
134 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
135 const int dW = stride.empty() ? kW :
136 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);
137
138 TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
139 "max_pool3d: padding must either be a single int, or a tuple of three ints");
140 const int pT = safe_downcast<int, int64_t>(padding[0]);
141 const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
142 const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);
143
144 TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
145 "max_pool3d: dilation must be either a single int, or a tuple of three ints");
146 const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
147 const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
148 const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);
149
150 TORCH_CHECK(input.dtype() == gradOutput.dtype(),
151 "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());
152
153 const auto memory_format = input.suggest_memory_format();
154 if (memory_format == at::MemoryFormat::ChannelsLast3d) {
155 TORCH_CHECK(input.ndimension() == 5,
156 "non-empty 5D (batch mode) tensor expected for input with channels_last_3d layout");
157 } else if (memory_format == at::MemoryFormat::Contiguous) {
158 TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
159 "non-empty 4D or 5D (batch mode) tensor expected for input");
160 } else {
161 TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous");
162 }
163
164 const int64_t nslices = input.size(-4);
165 const int64_t itime = input.size(-3);
166 const int64_t iheight = input.size(-2);
167 const int64_t iwidth = input.size(-1);
168
169
170 /* resize */
171 gradInput.resize_(input.sizes(), memory_format);
172 gradInput.zero_();
173
174 const int64_t otime = gradOutput.size(-3);
175 const int64_t oheight = gradOutput.size(-2);
176 const int64_t owidth = gradOutput.size(-1);
177
178 max_pool3d_backward_shape_check(
179 input,
180 gradOutput,
181 indices,
182 nslices,
183 kT, kH, kW,
184 dT, dH, dW,
185 pT, pH, pW,
186 dilationT, dilationH, dilationW,
187 itime, iheight, iwidth,
188 otime, oheight, owidth,
189 "max_pool3d_with_indices_backward_out_cpu_template()");
190
191 max_pool3d_backward_kernel(
192 kCPU, gradInput,
193 gradOutput, indices);
194
195 return gradInput;
196 }
197
198 } // namespace
199
max_pool3d_with_indices_out_cpu(const Tensor & input,IntArrayRef kernel_size,IntArrayRef stride,IntArrayRef padding,IntArrayRef dilation,bool ceil_mode,Tensor & output,Tensor & indices)200 std::tuple<Tensor&, Tensor&> max_pool3d_with_indices_out_cpu(const Tensor& input,
201 IntArrayRef kernel_size,
202 IntArrayRef stride,
203 IntArrayRef padding,
204 IntArrayRef dilation,
205 bool ceil_mode,
206 Tensor& output,
207 Tensor& indices)
208 {
209 max_pool3d_with_indices_out_cpu_template(
210 output,
211 indices,
212 input,
213 kernel_size,
214 stride,
215 padding,
216 dilation,
217 ceil_mode);
218 return std::tuple<Tensor&, Tensor&>(output, indices);
219 }
220
max_pool3d_with_indices_cpu(const Tensor & input,IntArrayRef kernel_size,IntArrayRef stride,IntArrayRef padding,IntArrayRef dilation,bool ceil_mode)221 std::tuple<Tensor, Tensor> max_pool3d_with_indices_cpu(
222 const Tensor& input,
223 IntArrayRef kernel_size,
224 IntArrayRef stride,
225 IntArrayRef padding,
226 IntArrayRef dilation,
227 bool ceil_mode)
228 {
229 NoNamesGuard guard;
230
231 Tensor output = at::empty({0}, input.options());
232 Tensor indices = at::empty({0}, input.options().dtype(kLong));
233 max_pool3d_with_indices_out_cpu_template(
234 output,
235 indices,
236 input,
237 kernel_size,
238 stride,
239 padding,
240 dilation,
241 ceil_mode);
242
243 guard.reset();
244 namedinference::propagate_names(output, input);
245 namedinference::propagate_names(indices, input);
246
247 return std::tuple<Tensor, Tensor>(output, indices);
248 }
249
// Out-variant backward entry point: a pure adapter around the shared
// backward template, which already returns the gradInput reference.
Tensor& max_pool3d_with_indices_backward_out_cpu(const Tensor& gradOutput_,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices,
Tensor& gradInput)
{
  return max_pool3d_with_indices_backward_out_cpu_template(
      gradInput, gradOutput_, input, indices,
      kernel_size, stride, padding, dilation,
      ceil_mode);
}
272
// Functional backward entry point: allocates the gradient tensor and lets
// the shared backward template resize, zero and fill it.
Tensor max_pool3d_with_indices_backward_cpu(
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  const Tensor& indices)
{
  // Placeholder; resized to input.sizes() inside the template.
  Tensor gradInput = at::empty({0}, input.options());
  return max_pool3d_with_indices_backward_out_cpu_template(
      gradInput, gradOutput_, input, indices,
      kernel_size, stride, padding, dilation,
      ceil_mode);
}
296
// Instantiate the DispatchStub registries for the pooling kernels used above;
// the concrete CPU kernel implementations are registered elsewhere (the stubs
// are declared in ATen/native/Pool.h — not visible in this file).
DEFINE_DISPATCH(max_pool3d_kernel);
DEFINE_DISPATCH(max_pool3d_backward_kernel);
299 } // namespace at::native
300