#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

#include <ATen/native/im2col.h>
#include <ATen/native/im2col_shape_check.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/col2im_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/im2col_native.h>
#endif

namespace at::native {
namespace {

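// Rearranges an input of shape (N, C, H, W) (or (C, H, W)) into a column
// matrix of shape (N, C * kH * kW, L), where each column is one flattened
// sliding-window patch and L is the number of window positions. This is the
// CPU forward path behind at::im2col / torch.nn.Unfold.
//
// Example: a (1, 2, 4, 4) input with kernel_size = (2, 2), stride = (1, 1),
// padding = (0, 0), dilation = (1, 1) produces an output of shape
// (1, 2 * 2 * 2, 3 * 3) = (1, 8, 9).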
static void im2col_out_cpu_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "Expected kernel_size to have 2 elements, but got ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "Expected dilation to have 2 elements, but got ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "Expected padding to have 2 elements, but got ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "Expected stride to have 2 elements, but got ",
      stride.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];

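  // Validate the kernel/dilation/padding/stride values against the input
  // geometry; this raises if the computed output size would be non-positive.
  // The empty Tensor() fills an argument slot this forward path does not use.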
  im2col_shape_check(
      input_,
      Tensor(),
      kernel_height,
      kernel_width,
      dilation_height,
      dilation_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width);

  Tensor input = input_.contiguous();

  bool batched_input = true;

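  // A 3-D (C, H, W) input is treated as a batch of one; the extra batch
  // dimension is stripped again before returning.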
  if (input.dim() == 3) {
    batched_input = false;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
  }

  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);

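  // Standard convolution output size:
  //   out = (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1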
  int64_t output_height = (input_height + 2 * pad_height -
                           (dilation_height * (kernel_height - 1) + 1)) /
          stride_height +
      1;
  int64_t output_width = (input_width + 2 * pad_width -
                          (dilation_width * (kernel_width - 1) + 1)) /
          stride_width +
      1;
  int64_t n_output_plane = n_input_plane * kernel_width * kernel_height;
  int64_t output_length = output_height * output_width;

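  // One output column per sliding-window position, each holding a flattened
  // (n_input_plane, kernel_height, kernel_width) patch.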
  output.resize_({batch_size, n_output_plane, output_length});

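  // Dispatch instantiates the copy kernel for the runtime dtype of the input.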
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
      input.scalar_type(), "im2col_out_cpu", [&] {
        Tensor input_n;
        Tensor output_n;

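        // select() returns a view into the batch, so each im2col call writes
        // its columns directly into the output tensor.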
        for (const auto elt : c10::irange(batch_size)) {
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

          im2col<scalar_t>(
              input_n.const_data_ptr<scalar_t>(),
              n_input_plane,
              input_height,
              input_width,
              output_height,
              output_width,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());
        }

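        // Drop the batch dimension that was added above for 3-D inputs.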
        if (!batched_input) {
          output.resize_({n_output_plane, output_length});
        }
      });
}

} // namespace

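// Out variant: writes the columns into the caller-provided output tensor.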
Tensor& im2col_out_cpu(const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    Tensor& output) {
  im2col_out_cpu_template(
      output, input, kernel_size, dilation, padding, stride);
  return output;
}

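// Functional variant: allocates the output itself. empty_like keeps the
// input's dtype and device; the final shape is set by resize_ in the
// template above.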
Tensor im2col_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  im2col_out_cpu_template(
      output, input, kernel_size, dilation, padding, stride);
  return output;
}

} // namespace at::native