1 #pragma once
2
3 #include <algorithm>
4 #include <vector>
5
6 #include <ATen/div_rtn.h>
7 #include <ATen/core/Tensor.h>
8 #include <c10/util/irange.h>
9
// TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)
//
// Raises (via TORCH_CHECK) unless tensor T has exactly DIM dimensions and
// T.size(DIM_SIZE) == SIZE. The error message names T (stringized) and
// reports T's actual shape.
//
// Every argument is parenthesized on expansion so that expression arguments
// (e.g. `cond ? a : b`, `x + 1`) act as a single operand of `==` instead of
// being re-associated by operator precedence.
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)          \
  TORCH_CHECK(                                                \
      (T).dim() == (DIM) && (T).size(DIM_SIZE) == (SIZE),     \
      "Need " #T " of dimension ",                            \
      (DIM),                                                  \
      " and " #T ".size[",                                    \
      (DIM_SIZE),                                             \
      "] == ",                                                \
      (SIZE),                                                 \
      " but got input to be of shape ",                       \
      (T).sizes())
21
22 namespace at::native::internal {
23 namespace {
all_positive(IntArrayRef & arr)24 inline bool all_positive(IntArrayRef& arr) {
25 return std::all_of(
26 arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
27 }
28
// Returns true iff every element of `arr` is >= 0 (vacuously true for an
// empty vector). The vector is only read, so it is taken by const reference;
// this also accepts temporaries and const vectors.
inline bool all_nonnegative(const std::vector<int64_t>& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}
33
34 } // namespace
35
36 // calculate the rear part of output tensor sizes
37 template <int64_t dim>
get_output_size(const Tensor & input,IntArrayRef kernel_size,IntArrayRef stride_size,IntArrayRef pad_size,IntArrayRef dilation_size)38 std::vector<int64_t> get_output_size(
39 const Tensor& input,
40 IntArrayRef kernel_size,
41 IntArrayRef stride_size,
42 IntArrayRef pad_size,
43 IntArrayRef dilation_size) {
44 std::vector<int64_t> sizes;
45 for (const auto index : c10::irange(dim)) {
46 sizes.push_back(
47 div_rtn<int64_t>(
48 input.size(index + input.dim() - dim) + 2 * pad_size[index] -
49 (dilation_size[index] * (kernel_size[index] - 1) + 1),
50 stride_size[index]) +
51 1);
52 }
53 return sizes;
54 }
55
56 // calculate the sizes of output tensor
57 template <int64_t dim>
get_output_size(const Tensor & input,const Tensor & weight,IntArrayRef kernel_size,IntArrayRef stride_size,IntArrayRef pad_size,IntArrayRef dilation_size)58 std::vector<int64_t> get_output_size(
59 const Tensor& input,
60 const Tensor& weight,
61 IntArrayRef kernel_size,
62 IntArrayRef stride_size,
63 IntArrayRef pad_size,
64 IntArrayRef dilation_size) {
65 auto output_size = get_output_size<dim>(
66 input, kernel_size, stride_size, pad_size, dilation_size);
67 output_size.insert(output_size.begin(), weight.size(0));
68 if (input.dim() == dim + 2) {
69 output_size.insert(output_size.begin(), input.size(0));
70 }
71 return output_size;
72 }
73 /*
74 slow_conv_dilated_shape_check - check user-input to dilated convolution
75 forward and backward functions.
76 */
77 template <int64_t dim>
slow_conv_dilated_shape_check(const Tensor & input,const Tensor & weight,const Tensor & bias,const Tensor & grad_output,IntArrayRef kernel_size,IntArrayRef stride_size,IntArrayRef pad_size,IntArrayRef dilation_size)78 void slow_conv_dilated_shape_check(
79 const Tensor& input,
80 const Tensor& weight,
81 const Tensor& bias,
82 const Tensor& grad_output,
83 IntArrayRef kernel_size,
84 IntArrayRef stride_size,
85 IntArrayRef pad_size,
86 IntArrayRef dilation_size) {
87 /*
88 When the following tensors are defined:
89
90 bias, grad_weight, grad_output
91
92 then these are assumed to be contiguous without checking
93 because of these tensors are made contiguous by calling
94 .contiguous() method or by resizing of zero-sized tensors in
95 forward/backward functions.
96
97 When grad_weight is defined then it is assumed without
98 checking to have the same shape as weight, see backward
99 functions.
100 */
101 // Check size arguments
102 TORCH_CHECK(
103 kernel_size.size() == dim,
104 "kernel sizes length should be ",
105 dim,
106 ", but got ",
107 kernel_size.size());
108 TORCH_CHECK(
109 stride_size.size() == dim,
110 "strides length should be ",
111 dim,
112 ", but got ",
113 stride_size.size());
114 TORCH_CHECK(
115 dilation_size.size() == dim,
116 "dilations length should be ",
117 dim,
118 ", but got ",
119 dilation_size.size());
120 TORCH_CHECK(
121 pad_size.size() == dim,
122 "pads length should be ",
123 dim,
124 ", but got ",
125 pad_size.size());
126
127 TORCH_CHECK(
128 all_positive(kernel_size),
129 "kernel size should be greater than zero, but got ",
130 kernel_size);
131 TORCH_CHECK(
132 all_positive(stride_size),
133 "stride should be greater than zero, but got ",
134 stride_size);
135 TORCH_CHECK(
136 all_positive(dilation_size),
137 "dilation should be greater than zero, but got ",
138 dilation_size);
139
140 // check input
141 TORCH_CHECK(input.defined(), "input must be defined");
142 bool is_batch = input.dim() == dim + 2;
143 int64_t n = (is_batch ? 2 : 1);
144 int64_t ndim = n + dim;
145 if (!is_batch) {
146 // input dim has to be dim + 1 if not batched
147 TORCH_CHECK(
148 input.dim() == dim + 1,
149 "input must be 4D or 5D tensor but got ",
150 input.dim(),
151 "D tensor");
152 }
153
154 // check output sizes
155 auto output_size = get_output_size<dim>(
156 input, kernel_size, stride_size, pad_size, dilation_size);
157
158 TORCH_CHECK(
159 all_nonnegative(output_size),
160 "calculated output size ",
161 output_size,
162 " is too small (all sizes must be non-negative)");
163
164 // check weight
165 TORCH_CHECK(weight.defined(), "weight must be defined");
166 TORCH_CHECK(
167 weight.dim() == dim + 2,
168 "weight must be ",
169 dim + 2,
170 "D tensor but got ",
171 weight.dim(),
172 "D tensor dim=",
173 dim);
174 TORCH_CHECK(
175 weight.sizes().slice(2) == kernel_size,
176 "weight[2:] shape ",
177 weight.sizes().slice(2),
178 " must be equal to kernel_size ",
179 kernel_size);
180
181 TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
182
183 // check bias when present
184 if (bias.defined()) {
185 TORCH_CHECK(
186 bias.dim() == 1,
187 "bias must be 1D tensor but got ",
188 bias.dim(),
189 "D tensor");
190 TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
191 }
192
193 // check grad_output when present
194 if (grad_output.defined()) {
195 TORCH_CHECK(
196 grad_output.dim() == ndim,
197 "grad_output must be ",
198 ndim,
199 "D tensor but got ",
200 grad_output.dim(),
201 "D tensor");
202 if (is_batch) {
203 TORCH_CHECK(
204 grad_output.size(0) == input.size(0),
205 "grad_output.size(0)=",
206 grad_output.size(0),
207 " must be input.size(0)=",
208 input.size(0));
209 }
210 TORCH_CHECK(
211 grad_output.size(n - 1) == weight.size(0),
212 "grad_output.size(",
213 n - 1,
214 ")=",
215 grad_output.size(n - 1),
216 " must be weight.size(0)=",
217 weight.size(0));
218 TORCH_CHECK(
219 grad_output.sizes().slice(n) == output_size,
220 "grad_output[",
221 n,
222 ":] shape",
223 grad_output.sizes().slice(n),
224 " must be equal to output size ",
225 output_size);
226 }
227 }
228
229 } // namespace at::native::internal
230