// aten/src/ATen/native/Col2Im.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

#include <ATen/native/im2col.h>
#include <ATen/native/im2col_shape_check.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/col2im_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/im2col_native.h>
#endif

// Note [im2col/col2im output padding]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Our implementations of im2col and col2im take both the input height/width as
// well as a seemingly redundant output height/width.  In principle, you could
// compute the output height/width by using the convolution shape formulas.  So,
// what's up with that?
//
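// For reference, the standard shape formulas (as documented for Conv2d and
// ConvTranspose2d) are:
//
//   conv:            out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1
//   transposed conv: out = (in - 1)*stride - 2*pad + dilation*(kernel - 1)
//                          + output_padding + 1
//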
// The trouble arises when one runs the backward of a transposed convolution
// with output_padding >= stride.  (BTW, output_padding is known as adj inside
// THNN.) Let's consider a simple case where we have kernel=2, dilation=2,
// stride=1, output_padding=1 for a 1x1 input (which produces a 4x4 output):
//
// Input:  X
//
// Output: X.X.
//         ....
//         X.X.
//         ....
//
// If we compute the backward pass by running a standard convolution over the
// output with the same parameters, we would end up with a 2x2 grad_input
// (because you can slide the stencil over to the right once and down once).
// But that is all out-of-bounds if you're computing backwards for a 1x1 input.
//
// "Now Edward," you might say, "the real problem is that you set output_padding
// >= stride, surely an error should have been raised in this case."  To
// understand why it is useful to handle this case, we have to understand how we
// compute the weight gradient of a convolution.  Suppose we have a convolution
// with kernel=2, stride=2 on a 5x5 input.  Let us see all the contributions of
// weight[0][0] (which we have labeled w) in the output:
//
// Input:  a.b..  Weight: w.
//         .....          ..
//         c.d..
//         .....
//         .....
//
// Output: [ aw+...  bw+... ]
//         [ cw+...  dw+... ]
//
// From this diagram, it is easy to see that we can compute the weight gradient
// by performing a *dilated* convolution between the input and the
// output gradients with kernel=2, dilation=2, stride=1.  But there's a rub: if
// we do a dilated convolution directly, we'll end up with a 3x3 weight
// gradient, when we clearly wanted a 2x2.  So how do we avoid going out
// of bounds?  We could add a notion of 'output_padding' for non-transposed
// convolution, but another simple and effective fix is to just accept
// the desired output size directly, and compute only within those bounds.
//
//
// ALSO do vol2col

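// A minimal round-trip sketch of these ops (not part of this file's
// implementation; the tensor shapes and parameter values are illustrative).
// With non-overlapping 2x2 windows, col2im exactly inverts im2col; with
// overlapping windows, values mapping to the same location are summed:
//
//   at::Tensor img  = at::rand({1, 3, 4, 4});
//   at::Tensor cols = at::im2col(img, /*kernel_size=*/{2, 2}, /*dilation=*/{1, 1},
//                                /*padding=*/{0, 0}, /*stride=*/{2, 2});
//   // cols: (1, 3*2*2, 4) -> fold back to (1, 3, 4, 4)
//   at::Tensor back = at::col2im(cols, /*output_size=*/{4, 4}, {2, 2}, {1, 1},
//                                {0, 0}, {2, 2});
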
namespace at::native {
namespace {

static void col2im_out_cpu_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());

  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];
  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];

  col2im_shape_check(
      input_,
      Tensor(),
      output_height,
      output_width,
      kernel_height,
      kernel_width,
      dilation_height,
      dilation_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width);

  Tensor input = input_.contiguous();

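  // A 2D input of shape (C*kH*kW, L) is treated as a batch of one here; the
  // extra batch dimension is stripped again after the main loop below.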
  bool batched_input = true;
  if (input.dim() == 2) {
    // Force batch
    batched_input = false;
    input = input.view({1, input.size(0), input.size(1)});
  }

  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
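  // Each column packs kernel_height * kernel_width values per output channel,
  // so the channel count of the folded image is recovered by dividing.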
  int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height);

  output.resize_({batch_size, n_output_plane, output_height, output_width});

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
      input.scalar_type(), "col2im_out_cpu", [&] {
        Tensor input_n = Tensor();
        Tensor output_n = Tensor();

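        // Spatial extent of the column buffer: the number of sliding-window
        // positions per dimension, per the standard convolution shape formula.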
        int64_t height_col = (output_height + 2 * pad_height -
                              (dilation_height * (kernel_height - 1) + 1)) /
                stride_height +
            1;
        int64_t width_col = (output_width + 2 * pad_width -
                             (dilation_width * (kernel_width - 1) + 1)) /
                stride_width +
            1;

        for (const auto elt : c10::irange(batch_size)) {
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

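          // Scatter the column values of this batch element back into the
          // output image; values mapping to the same location are summed.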
          col2im<scalar_t>(
              input_n.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_height,
              output_width,
              height_col,
              width_col,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());
        }

        if (!batched_input) {
          output.resize_({n_output_plane, output_height, output_width});
        }
      });
}

} // namespace

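// The functions below are the CPU entry points for the col2im operator and its
// out= variant (registered via native_functions.yaml); at the Python level
// this op backs torch.nn.functional.fold.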
Tensor& col2im_out_cpu(const Tensor& input,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    Tensor& output) {
  col2im_out_cpu_template(
      output, input, output_size, kernel_size, dilation, padding, stride);
  return output;
}

Tensor col2im_cpu(
    const Tensor& input,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  col2im_out_cpu_template(
      output, input, output_size, kernel_size, dilation, padding, stride);
  return output;
}

} // namespace at::native