1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorUtils.h>
4 #include <ATen/div_rtn.h>
5
6 namespace at::native {
7
col2im_shape_check(const Tensor & input,const Tensor & grad_output,int64_t output_height,int64_t output_width,int64_t kernel_height,int64_t kernel_width,int64_t dilation_height,int64_t dilation_width,int64_t pad_height,int64_t pad_width,int64_t stride_height,int64_t stride_width)8 inline void col2im_shape_check(
9 const Tensor& input,
10 const Tensor& grad_output,
11 int64_t output_height,
12 int64_t output_width,
13 int64_t kernel_height,
14 int64_t kernel_width,
15 int64_t dilation_height,
16 int64_t dilation_width,
17 int64_t pad_height,
18 int64_t pad_width,
19 int64_t stride_height,
20 int64_t stride_width) {
21 TORCH_CHECK(
22 kernel_width > 0 && kernel_height > 0,
23 "kernel size should be greater than zero, but got kernel_height: ",
24 kernel_height,
25 " kernel_width: ",
26 kernel_width);
27 TORCH_CHECK(
28 stride_width > 0 && stride_height > 0,
29 "stride should be greater than zero, but got stride_height: ",
30 stride_height,
31 " stride_width: ",
32 stride_width);
33 TORCH_CHECK(
34 dilation_width > 0 && dilation_height > 0,
35 "dilation should be greater than zero, but got dilation_height: ",
36 dilation_height,
37 " dilation_width: ",
38 dilation_width);
39 TORCH_CHECK(
40 pad_width >= 0 && pad_height >= 0,
41 "padding should be non-negative, but got pad_height: ",
42 pad_height,
43 " pad_width: ",
44 pad_width);
45
46
47 int64_t ndim = input.ndimension();
48 // allow dim=0 only the batch dimension.
49 TORCH_CHECK(
50 (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
51 (ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
52 "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
53 input.sizes());
54
55 int64_t batch_dim = (ndim == 3) ? 0 : -1;
56 int64_t n_input_plane = input.size(batch_dim + 1);
57
58 if (n_input_plane % (kernel_width * kernel_height) != 0) {
59 AT_ERROR(
60 "Expected size of input's dimension 1 to be divisible by the "
61 "product of kernel_size, but got input.size(1)=",
62 n_input_plane,
63 " and kernel_size=(",
64 kernel_height,
65 ", ",
66 kernel_width,
67 ").");
68 }
69
70 int64_t input_length = input.size(batch_dim + 2);
71 int64_t n_blocks_height =
72 div_rtn<int64_t>(
73 output_height + 2 * pad_height -
74 dilation_height * (kernel_height - 1) - 1,
75 stride_height) +
76 1;
77 int64_t n_blocks_width = div_rtn<int64_t>(
78 output_width + 2 * pad_width -
79 dilation_width * (kernel_width - 1) - 1,
80 stride_width) +
81 1;
82
83 if (input_length != (n_blocks_height * n_blocks_width)) {
84 AT_ERROR(
85 "Given output_size=(",
86 output_height,
87 ", ",
88 output_width,
89 "), kernel_size=(",
90 kernel_height,
91 ", ",
92 kernel_width,
93 "), dilation=(",
94 dilation_height,
95 ", ",
96 dilation_width,
97 "), padding=(",
98 pad_height,
99 ", ",
100 pad_width,
101 "), stride=(",
102 stride_height,
103 ", ",
104 stride_width,
105 "), expected size of input's dimension 2 to match the calculated number of ",
106 "sliding blocks ",
107 n_blocks_height,
108 " * ",
109 n_blocks_width,
110 " = ",
111 (n_blocks_height * n_blocks_width),
112 ", but got input.size(2)=",
113 input_length,
114 ".");
115 }
116
117 TORCH_CHECK(
118 n_blocks_height >= 1 && n_blocks_width >= 1,
119 "Given output_size=(", output_height, ", ", output_width, "), ",
120 "kernel_size=(", kernel_height, ", ", kernel_width, "), ",
121 "dilation=(", dilation_height, ", ", dilation_width, "), ",
122 "padding=(", pad_height, ", ", pad_width, "), ",
123 "stride=(", stride_height, ", ", stride_width, "), ",
124 "calculated shape of the array of sliding blocks as ",
125 "(", n_blocks_height, ", ", n_blocks_width, "), ",
126 "which is too small (non-positive)");
127
128 if (output_width < 1 || output_height < 1) {
129 AT_ERROR(
130 "Expected output spatial size to be positive, but got: output_size=(",
131 output_height,
132 ", ",
133 output_width,
134 ").");
135 }
136 }
137
im2col_shape_check(const Tensor & input,const Tensor & grad_output,int64_t kernel_height,int64_t kernel_width,int64_t dilation_height,int64_t dilation_width,int64_t pad_height,int64_t pad_width,int64_t stride_height,int64_t stride_width)138 inline void im2col_shape_check(
139 const Tensor& input,
140 const Tensor& grad_output,
141 int64_t kernel_height,
142 int64_t kernel_width,
143 int64_t dilation_height,
144 int64_t dilation_width,
145 int64_t pad_height,
146 int64_t pad_width,
147 int64_t stride_height,
148 int64_t stride_width) {
149 TORCH_CHECK(
150 kernel_width > 0 && kernel_height > 0,
151 "kernel size should be greater than zero, but got kernel_height: ",
152 kernel_height,
153 " kernel_width: ",
154 kernel_width);
155
156 TORCH_CHECK(
157 dilation_width > 0 && dilation_height > 0,
158 "dilation should be greater than zero, but got dilation_height: ",
159 dilation_height,
160 " dilation_width: ",
161 dilation_width);
162
163 TORCH_CHECK(
164 pad_width >= 0 && pad_height >= 0,
165 "padding should be non-negative, but got pad_height: ",
166 pad_height,
167 " pad_width: ",
168 pad_width);
169
170 TORCH_CHECK(
171 stride_width > 0 && stride_height > 0,
172 "stride should be greater than zero, but got stride_height: ",
173 stride_height,
174 " stride_width: ",
175 stride_width);
176
177 int64_t ndim = input.ndimension();
178
179 // allow dim=0 only the batch dimension.
180 bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
181 TORCH_CHECK(
182 (ndim == 3 && input.size(0) && valid_dims) ||
183 (ndim == 4 && valid_dims && input.size(3) != 0),
184 "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
185 input.sizes());
186
187 int64_t dim_batch = 0;
188
189 if (ndim == 3) {
190 dim_batch = -1;
191 }
192
193 int64_t input_height = input.size(dim_batch + 2);
194 int64_t input_width = input.size(dim_batch + 3);
195 int64_t output_height = div_rtn<int64_t>(
196 input_height + 2 * pad_height -
197 (dilation_height * (kernel_height - 1) + 1),
198 stride_height) +
199 1;
200 int64_t output_width = div_rtn<int64_t>(
201 input_width + 2 * pad_width -
202 (dilation_width * (kernel_width - 1) + 1),
203 stride_width) +
204 1;
205
206 if (output_height < 1 || output_width < 1) {
207 AT_ERROR(
208 "Given input with spatial size (",
209 input_height,
210 ", ",
211 input_height,
212 "), kernel_size=(",
213 kernel_height,
214 ", ",
215 kernel_width,
216 "), dilation=(",
217 dilation_height,
218 ", ",
219 dilation_width,
220 "), padding=(",
221 pad_height,
222 ", ",
223 pad_width,
224 "), calculated shape of the array of sliding blocks as (",
225 output_height,
226 ", ",
227 output_width,
228 "), but its components must be at least one.");
229 }
230 }
231
232 } // namespace at::native
233