#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

#include <ATen/native/im2col.h>
#include <ATen/native/im2col_shape_check.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/col2im_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/im2col_native.h>
#endif

// Note [im2col/col2im output padding]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Our implementations of im2col and col2im take both the input height/width as
// well as a seemingly redundant output height/width. In principle, you could
// compute the output height/width by using the convolution shape formulas. So,
// what's up with that?
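//
// (For reference, the one-dimensional shape formulas are, for an ordinary
// convolution,
//
//   out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
//
// and, for a transposed convolution,
//
//   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1)
//         + output_padding + 1.
//
// The worked examples below plug into these.)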
//
// The trouble arises when one runs the backward of a transposed convolution
// with output_padding >= stride. (BTW, output_padding is known as adj inside
// THNN.) Let's consider a simple case where we have kernel=2, dilation=2,
// stride=1, output_padding=1 for a 1x1 input (which yields a 4x4 output):
//
// Input:  X
//
// Output: X.X.
//         ....
//         X.X.
//         ....
//
// If we run the backward of this output as a standard convolution with the
// same parameters, we end up with a 2x2 grad_input (because you can slide the
// stencil over to the right once and down once). But that is all
// out-of-bounds if you're computing backwards for a 1x1 input.
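//
// (Concretely, plugging the 4x4 output into the ordinary shape formula above
// with kernel=2, dilation=2, stride=1, padding=0 gives
//
//   (4 + 0 - 2 * (2 - 1) - 1) / 1 + 1 = 2,
//
// a 2x2 result, even though the true grad_input is only 1x1.)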
//
// "Now Edward," you might say, "the real problem is that you set
// output_padding >= stride, surely an error should have been raised in this
// case." To understand why it is useful to handle this case, we have to
// understand how we compute the weight gradient of a convolution. Suppose we
// have a convolution with kernel=2, stride=2 on a 5x5 input. Let us see all
// the contributions of weight[0][0] (which we have labeled w) in the output:
//
// Input:  a.b..    Weight: w.
//         .....            ..
//         c.d..
//         .....
//         .....
//
// Output: [ aw+... bw+... ]
//         [ cw+... dw+... ]
//
// From this diagram, it is easy to see that we can compute the weight gradient
// by performing a *dilated* convolution between the input and the
// output gradients with kernel=2, dilation=2, stride=1. But there's a rub: if
// we do a dilated convolution directly, we'll end up with a 3x3 weight
// gradient, when we clearly wanted a 2x2. So how do we avoid going out
// of bounds? We could add a notion of 'output_padding' for non-transposed
// convolution, but another simple and effective fix is to just accept
// the desired output size directly, and compute only within those bounds.
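//
// (Again by the ordinary shape formula: the dilated convolution over the 5x5
// input with kernel=2, dilation=2, stride=1, padding=0 produces
//
//   (5 + 0 - 2 * (2 - 1) - 1) / 1 + 1 = 3,
//
// i.e. a 3x3 result, one row and one column more than the 2x2 weight whose
// gradient we actually want.)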
//
// TODO: also do vol2col

namespace at::native {
namespace {

static void col2im_out_cpu_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  TORCH_CHECK(
      output_size.size() == 2,
      "Expected output_size to have 2 elements, but got ",
      output_size.size());

  TORCH_CHECK(
      kernel_size.size() == 2,
      "Expected kernel_size to have 2 elements, but got ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "Expected dilation to have 2 elements, but got ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "Expected padding to have 2 elements, but got ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "Expected stride to have 2 elements, but got ",
      stride.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];
  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];

  col2im_shape_check(
      input_,
      Tensor(),
      output_height,
      output_width,
      kernel_height,
      kernel_width,
      dilation_height,
      dilation_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width);

  Tensor input = input_.contiguous();

  bool batched_input = true;
  if (input.dim() == 2) {
    // Force batch
    batched_input = false;
    input = input.view({1, input.size(0), input.size(1)});
  }

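  // The column input has n_input_plane = C * kernel_height * kernel_width
  // planes; dividing out the kernel footprint recovers the channel count C
  // of the folded output image.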
  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
  int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height);

  output.resize_({batch_size, n_output_plane, output_height, output_width});

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
      input.scalar_type(), "col2im_out_cpu", [&] {
        Tensor input_n = Tensor();
        Tensor output_n = Tensor();

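        // height_col/width_col are the spatial extents of the column buffer,
        // i.e. the number of kernel placements along each axis, given by the
        // ordinary convolution shape formula (see the note at the top of
        // this file).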
        int64_t height_col = (output_height + 2 * pad_height -
                              (dilation_height * (kernel_height - 1) + 1)) /
                stride_height +
            1;
        int64_t width_col = (output_width + 2 * pad_width -
                             (dilation_width * (kernel_width - 1) + 1)) /
                stride_width +
            1;

        for (const auto elt : c10::irange(batch_size)) {
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

          col2im<scalar_t>(
              input_n.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_height,
              output_width,
              height_col,
              width_col,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());
        }

        if (!batched_input) {
          output.resize_({n_output_plane, output_height, output_width});
        }
      });
}

} // namespace

Tensor& col2im_out_cpu(const Tensor& input,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    Tensor& output) {
  col2im_out_cpu_template(
      output, input, output_size, kernel_size, dilation, padding, stride);
  return output;
}

Tensor col2im_cpu(
    const Tensor& input,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  col2im_out_cpu_template(
      output, input, output_size, kernel_size, dilation, padding, stride);
  return output;
}
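
// A minimal usage sketch (illustrative only; the shapes below are assumptions
// for the example, not part of this file). Folding a column buffer back into
// an image with 3 channels, a 2x2 kernel, stride 1, no padding or dilation,
// and a 4x4 target: the buffer has 3*2*2 = 12 planes and 3*3 = 9 placements.
//
//   at::Tensor cols = at::randn({1, 12, 9});
//   at::Tensor img = at::col2im(
//       cols, /*output_size=*/{4, 4}, /*kernel_size=*/{2, 2},
//       /*dilation=*/{1, 1}, /*padding=*/{0, 0}, /*stride=*/{1, 1});
//   // img has shape [1, 3, 4, 4]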

} // namespace at::native