xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Im2Col.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/TensorUtils.h>
5 
6 #include <ATen/native/im2col.h>
7 #include <ATen/native/im2col_shape_check.h>
8 #include <c10/util/irange.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/col2im_native.h>
15 #include <ATen/ops/empty_like.h>
16 #include <ATen/ops/im2col_native.h>
17 #endif
18 
19 namespace at::native {
20 namespace {
21 
im2col_out_cpu_template(Tensor & output,const Tensor & input_,IntArrayRef kernel_size,IntArrayRef dilation,IntArrayRef padding,IntArrayRef stride)22 static void im2col_out_cpu_template(
23     Tensor& output,
24     const Tensor& input_,
25     IntArrayRef kernel_size,
26     IntArrayRef dilation,
27     IntArrayRef padding,
28     IntArrayRef stride) {
29   TORCH_CHECK(
30       kernel_size.size() == 2,
31       "It is expected kernel_size equals to 2, but got size ",
32       kernel_size.size());
33 
34   TORCH_CHECK(
35       dilation.size() == 2,
36       "It is expected dilation equals to 2, but got size ",
37       dilation.size());
38 
39   TORCH_CHECK(
40       padding.size() == 2,
41       "It is expected padding equals to 2, but got size ",
42       padding.size());
43 
44   TORCH_CHECK(
45       stride.size() == 2,
46       "It is expected stride equals to 2, but got size ",
47       stride.size());
48 
49   int64_t kernel_height = kernel_size[0];
50   int64_t kernel_width = kernel_size[1];
51   int64_t dilation_height = dilation[0];
52   int64_t dilation_width = dilation[1];
53   int64_t pad_height = padding[0];
54   int64_t pad_width = padding[1];
55   int64_t stride_height = stride[0];
56   int64_t stride_width = stride[1];
57 
58   im2col_shape_check(
59       input_,
60       Tensor(),
61       kernel_height,
62       kernel_width,
63       dilation_height,
64       dilation_width,
65       pad_height,
66       pad_width,
67       stride_height,
68       stride_width);
69 
70   Tensor input = input_.contiguous();
71 
72   bool batched_input = true;
73 
74   if (input.dim() == 3) {
75     batched_input = false;
76     input = input.view({1, input.size(0), input.size(1), input.size(2)});
77   }
78 
79   int64_t batch_size = input.size(0);
80   int64_t n_input_plane = input.size(1);
81   int64_t input_height = input.size(2);
82   int64_t input_width = input.size(3);
83 
84   int64_t output_height = (input_height + 2 * pad_height -
85                            (dilation_height * (kernel_height - 1) + 1)) /
86           stride_height +
87       1;
88   int64_t output_width = (input_width + 2 * pad_width -
89                           (dilation_width * (kernel_width - 1) + 1)) /
90           stride_width +
91       1;
92   int64_t n_output_plane = n_input_plane * kernel_width * kernel_height;
93   int64_t output_length = output_height * output_width;
94 
95   output.resize_({batch_size, n_output_plane, output_length});
96 
97   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
98       input.scalar_type(), "im2col_out_cpu", [&] {
99         Tensor input_n;
100         Tensor output_n;
101 
102         for (const auto elt : c10::irange(batch_size)) {
103           input_n = input.select(0, elt);
104           output_n = output.select(0, elt);
105 
106           im2col<scalar_t>(
107               input_n.const_data_ptr<scalar_t>(),
108               n_input_plane,
109               input_height,
110               input_width,
111               output_height,
112               output_width,
113               kernel_height,
114               kernel_width,
115               pad_height,
116               pad_width,
117               stride_height,
118               stride_width,
119               dilation_height,
120               dilation_width,
121               output_n.mutable_data_ptr<scalar_t>());
122         }
123 
124         if (!batched_input) {
125           output.resize_({n_output_plane, output_length});
126         }
127       });
128 }
129 
130 } // namespace
131 
// Out-variant of im2col on CPU.
//
// Unfolds `input` into `output` (which is resized as needed) and returns
// `output` to allow call chaining. All validation and the actual work
// happen in the shared template above.
Tensor& im2col_out_cpu(const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    Tensor& output) {
  im2col_out_cpu_template(output, input, kernel_size, dilation, padding, stride);
  return output;
}
142 
// Functional variant of im2col on CPU.
//
// Allocates a fresh result tensor matching the input's dtype/device and
// delegates to the shared template, which resizes it to the proper
// {N, C*kH*kW, L} (or {C*kH*kW, L} for 3-D input) shape.
Tensor im2col_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  // Initial shape is irrelevant; the template resize_()s the result.
  Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  im2col_out_cpu_template(
      result, input, kernel_size, dilation, padding, stride);
  return result;
}
155 
156 } // namespace at::native
157