xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/im2col_shape_check.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorUtils.h>
4 #include <ATen/div_rtn.h>
5 
6 namespace at::native {
7 
col2im_shape_check(const Tensor & input,const Tensor & grad_output,int64_t output_height,int64_t output_width,int64_t kernel_height,int64_t kernel_width,int64_t dilation_height,int64_t dilation_width,int64_t pad_height,int64_t pad_width,int64_t stride_height,int64_t stride_width)8 inline void col2im_shape_check(
9     const Tensor& input,
10     const Tensor& grad_output,
11     int64_t output_height,
12     int64_t output_width,
13     int64_t kernel_height,
14     int64_t kernel_width,
15     int64_t dilation_height,
16     int64_t dilation_width,
17     int64_t pad_height,
18     int64_t pad_width,
19     int64_t stride_height,
20     int64_t stride_width) {
21   TORCH_CHECK(
22       kernel_width > 0 && kernel_height > 0,
23       "kernel size should be greater than zero, but got kernel_height: ",
24       kernel_height,
25       " kernel_width: ",
26       kernel_width);
27   TORCH_CHECK(
28       stride_width > 0 && stride_height > 0,
29       "stride should be greater than zero, but got stride_height: ",
30       stride_height,
31       " stride_width: ",
32       stride_width);
33   TORCH_CHECK(
34       dilation_width > 0 && dilation_height > 0,
35       "dilation should be greater than zero, but got dilation_height: ",
36       dilation_height,
37       " dilation_width: ",
38       dilation_width);
39   TORCH_CHECK(
40       pad_width >= 0 && pad_height >= 0,
41       "padding should be non-negative, but got pad_height: ",
42       pad_height,
43       " pad_width: ",
44       pad_width);
45 
46 
47   int64_t ndim = input.ndimension();
48   // allow dim=0 only the batch dimension.
49   TORCH_CHECK(
50       (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
51       (ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
52       "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
53       input.sizes());
54 
55   int64_t batch_dim = (ndim == 3) ? 0 : -1;
56   int64_t n_input_plane = input.size(batch_dim + 1);
57 
58   if (n_input_plane % (kernel_width * kernel_height) != 0) {
59     AT_ERROR(
60         "Expected size of input's dimension 1 to be divisible by the "
61         "product of kernel_size, but got input.size(1)=",
62         n_input_plane,
63         " and kernel_size=(",
64         kernel_height,
65         ", ",
66         kernel_width,
67         ").");
68   }
69 
70   int64_t input_length = input.size(batch_dim + 2);
71   int64_t n_blocks_height =
72       div_rtn<int64_t>(
73           output_height + 2 * pad_height -
74               dilation_height * (kernel_height - 1) - 1,
75           stride_height) +
76       1;
77   int64_t n_blocks_width = div_rtn<int64_t>(
78                                    output_width + 2 * pad_width -
79                                        dilation_width * (kernel_width - 1) - 1,
80                                    stride_width) +
81       1;
82 
83   if (input_length != (n_blocks_height * n_blocks_width)) {
84     AT_ERROR(
85         "Given output_size=(",
86         output_height,
87         ", ",
88         output_width,
89         "), kernel_size=(",
90         kernel_height,
91         ", ",
92         kernel_width,
93         "), dilation=(",
94         dilation_height,
95         ", ",
96         dilation_width,
97         "), padding=(",
98         pad_height,
99         ", ",
100         pad_width,
101         "), stride=(",
102         stride_height,
103         ", ",
104         stride_width,
105         "), expected size of input's dimension 2 to match the calculated number of ",
106         "sliding blocks ",
107         n_blocks_height,
108         " * ",
109         n_blocks_width,
110         " = ",
111         (n_blocks_height * n_blocks_width),
112         ", but got input.size(2)=",
113         input_length,
114         ".");
115   }
116 
117   TORCH_CHECK(
118     n_blocks_height >= 1 && n_blocks_width >= 1,
119     "Given output_size=(", output_height, ", ", output_width, "), ",
120     "kernel_size=(", kernel_height, ", ", kernel_width, "), ",
121     "dilation=(", dilation_height, ", ", dilation_width, "), ",
122     "padding=(", pad_height, ", ", pad_width, "), ",
123     "stride=(", stride_height, ", ", stride_width, "), ",
124     "calculated shape of the array of sliding blocks as ",
125     "(", n_blocks_height, ", ", n_blocks_width, "), ",
126     "which is too small (non-positive)");
127 
128   if (output_width < 1 || output_height < 1) {
129     AT_ERROR(
130         "Expected output spatial size to be positive, but got: output_size=(",
131         output_height,
132         ", ",
133         output_width,
134         ").");
135   }
136 }
137 
im2col_shape_check(const Tensor & input,const Tensor & grad_output,int64_t kernel_height,int64_t kernel_width,int64_t dilation_height,int64_t dilation_width,int64_t pad_height,int64_t pad_width,int64_t stride_height,int64_t stride_width)138 inline void im2col_shape_check(
139     const Tensor& input,
140     const Tensor& grad_output,
141     int64_t kernel_height,
142     int64_t kernel_width,
143     int64_t dilation_height,
144     int64_t dilation_width,
145     int64_t pad_height,
146     int64_t pad_width,
147     int64_t stride_height,
148     int64_t stride_width) {
149   TORCH_CHECK(
150       kernel_width > 0 && kernel_height > 0,
151       "kernel size should be greater than zero, but got kernel_height: ",
152       kernel_height,
153       " kernel_width: ",
154       kernel_width);
155 
156   TORCH_CHECK(
157       dilation_width > 0 && dilation_height > 0,
158       "dilation should be greater than zero, but got dilation_height: ",
159       dilation_height,
160       " dilation_width: ",
161       dilation_width);
162 
163   TORCH_CHECK(
164       pad_width >= 0 && pad_height >= 0,
165       "padding should be non-negative, but got pad_height: ",
166       pad_height,
167       " pad_width: ",
168       pad_width);
169 
170   TORCH_CHECK(
171       stride_width > 0 && stride_height > 0,
172       "stride should be greater than zero, but got stride_height: ",
173       stride_height,
174       " stride_width: ",
175       stride_width);
176 
177   int64_t ndim = input.ndimension();
178 
179   // allow dim=0 only the batch dimension.
180   bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
181   TORCH_CHECK(
182       (ndim == 3 && input.size(0) && valid_dims) ||
183       (ndim == 4 && valid_dims && input.size(3) != 0),
184       "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
185       input.sizes());
186 
187   int64_t dim_batch = 0;
188 
189   if (ndim == 3) {
190     dim_batch = -1;
191   }
192 
193   int64_t input_height = input.size(dim_batch + 2);
194   int64_t input_width = input.size(dim_batch + 3);
195   int64_t output_height = div_rtn<int64_t>(
196                               input_height + 2 * pad_height -
197                                   (dilation_height * (kernel_height - 1) + 1),
198                               stride_height) +
199       1;
200   int64_t output_width = div_rtn<int64_t>(
201                              input_width + 2 * pad_width -
202                                  (dilation_width * (kernel_width - 1) + 1),
203                              stride_width) +
204       1;
205 
206   if (output_height < 1 || output_width < 1) {
207     AT_ERROR(
208         "Given input with spatial size (",
209         input_height,
210         ", ",
211         input_height,
212         "), kernel_size=(",
213         kernel_height,
214         ", ",
215         kernel_width,
216         "), dilation=(",
217         dilation_height,
218         ", ",
219         dilation_width,
220         "), padding=(",
221         pad_height,
222         ", ",
223         pad_width,
224         "), calculated shape of the array of sliding blocks as (",
225         output_height,
226         ", ",
227         output_width,
228         "), but its components must be at least one.");
229   }
230 }
231 
232 } // namespace at::native
233