#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/Pool.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#endif

namespace at::native {

namespace {

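// Forward CPU implementation shared by the out= and functional variants:
// validates the pooling parameters, computes the output shape, resizes
// `output` and `indices`, and dispatches to the max_pool3d_kernel stub.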
void max_pool3d_with_indices_out_cpu_template(
          Tensor& output,
          Tensor& indices,
          const Tensor& input,
          IntArrayRef kernel_size,
          IntArrayRef stride,
          IntArrayRef padding,
          IntArrayRef dilation,
          bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints")
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints")
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast3d) {
    TORCH_CHECK(input.ndimension() == 5,
      "non-empty 5D (batch mode) tensor expected for input with channels_last_3d layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }

  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

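  // Output extent per spatial dim, as computed by pooling_output_shape:
  //   out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
  // With ceil_mode the division rounds up instead, and the result is clamped
  // so that the last pooling window starts inside the input (or its padding).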
  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, pT, dT, dilationT, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_out_cpu_template()");

  if (input.dim() == 4) { /* non-batch mode */
    /* resize output */
    output.resize_({nslices, otime, oheight, owidth});
    /* indices will contain ti,i,j locations for each output point */
    indices.resize_({nslices, otime, oheight, owidth});
  }
  else { /* batch mode */
    const int64_t nbatch = input.size(0);

    /* resize output */
    output.resize_({nbatch, nslices, otime, oheight, owidth}, memory_format);
    /* indices will contain ti,i,j locations for each output point */
    indices.resize_({nbatch, nslices, otime, oheight, owidth}, memory_format);
  }
  max_pool3d_kernel(
      kCPU, output, indices, input,
      kW, kH, kT,
      dW, dH, dT,
      pW, pH, pT,
      dilationW, dilationH, dilationT);
}

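// Backward CPU implementation shared by the out= and functional variants:
// re-validates the pooling parameters, resizes gradInput to the input's shape
// and zeroes it, then dispatches to max_pool3d_backward_kernel, which scatters
// gradOutput back into gradInput at the locations recorded in `indices`.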
Tensor& max_pool3d_with_indices_backward_out_cpu_template(
          Tensor& gradInput,
          const Tensor& gradOutput,
          const Tensor& input,
          const Tensor& indices,
          IntArrayRef kernel_size,
          IntArrayRef stride,
          IntArrayRef padding,
          IntArrayRef dilation,
          bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints")
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints")
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());

  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast3d) {
    TORCH_CHECK(input.ndimension() == 5,
      "non-empty 5D (batch mode) tensor expected for input with channels_last_3d layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }

  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  /* resize */
  gradInput.resize_(input.sizes(), memory_format);
  gradInput.zero_();

  const int64_t otime = gradOutput.size(-3);
  const int64_t oheight = gradOutput.size(-2);
  const int64_t owidth = gradOutput.size(-1);

  max_pool3d_backward_shape_check(
    input,
    gradOutput,
    indices,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_backward_out_cpu_template()");

  max_pool3d_backward_kernel(
      kCPU, gradInput,
      gradOutput, indices);

  return gradInput;
}

} // namespace

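// Public native-function entry points for the CPU backend. Each of the four
// variants (out= and functional, forward and backward) forwards to the shared
// templates above.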
std::tuple<Tensor&, Tensor&> max_pool3d_with_indices_out_cpu(const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  Tensor& output,
  Tensor& indices)
{
  max_pool3d_with_indices_out_cpu_template(
    output,
    indices,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return std::tuple<Tensor&, Tensor&>(output, indices);
}

std::tuple<Tensor, Tensor> max_pool3d_with_indices_cpu(
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode)
{
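  // Suppress named-tensor name propagation while the outputs are being
  // resized and written; names are propagated explicitly once the guard
  // is reset below.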
  NoNamesGuard guard;

  Tensor output = at::empty({0}, input.options());
  Tensor indices = at::empty({0}, input.options().dtype(kLong));
  max_pool3d_with_indices_out_cpu_template(
    output,
    indices,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);

  guard.reset();
  namedinference::propagate_names(output, input);
  namedinference::propagate_names(indices, input);

  return std::tuple<Tensor, Tensor>(output, indices);
}

Tensor& max_pool3d_with_indices_backward_out_cpu(const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  const Tensor& indices,
  Tensor& gradInput)
{
  max_pool3d_with_indices_backward_out_cpu_template(
    gradInput,
    gradOutput_,
    input,
    indices,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return gradInput;
}

Tensor max_pool3d_with_indices_backward_cpu(
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  const Tensor& indices)
{
  auto gradInput = at::empty({0}, input.options());
  max_pool3d_with_indices_backward_out_cpu_template(
    gradInput,
    gradOutput_,
    input,
    indices,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return gradInput;
}

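// Define the DispatchStub instances declared via DECLARE_DISPATCH in
// ATen/native/Pool.h; the architecture-specific CPU kernels register
// themselves against these stubs elsewhere.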
DEFINE_DISPATCH(max_pool3d_kernel);
DEFINE_DISPATCH(max_pool3d_backward_kernel);
} // namespace at::native