#pragma once

#include <algorithm>
#include <vector>

#include <ATen/div_rtn.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

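// Checks that tensor T has exactly DIM dimensions and that its size along
// dimension DIM_SIZE equals SIZE, producing a descriptive error otherwise.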
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
  TORCH_CHECK(                                       \
      T.dim() == DIM && T.size(DIM_SIZE) == SIZE,    \
      "Need " #T " of dimension ",                   \
      DIM,                                           \
      " and " #T ".size[",                           \
      DIM_SIZE,                                      \
      "] == ",                                       \
      SIZE,                                          \
      " but got input to be of shape ",              \
      T.sizes())

namespace at::native::internal {
namespace {
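// Predicates used below to validate kernel/stride/dilation and output sizes.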
inline bool all_positive(IntArrayRef& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
}

inline bool all_nonnegative(std::vector<int64_t>& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}

} // namespace

// Calculate the trailing (spatial) part of the output tensor sizes.
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  std::vector<int64_t> sizes;
  for (const auto index : c10::irange(dim)) {
    sizes.push_back(
        div_rtn<int64_t>(
            input.size(index + input.dim() - dim) + 2 * pad_size[index] -
                (dilation_size[index] * (kernel_size[index] - 1) + 1),
            stride_size[index]) +
        1);
  }
  return sizes;
}
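
// The per-dimension computation above is the standard convolution output
// size formula,
//   out = floor((in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride) + 1,
// where div_rtn rounds toward negative infinity. A worked example with
// values chosen for illustration (not taken from the code): in = 10,
// kernel = 3, pad = 1, stride = 2, dilation = 2 gives
// floor((10 + 2 - 5) / 2) + 1 = 3 + 1 = 4.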

// Calculate the full sizes of the output tensor.
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  output_size.insert(output_size.begin(), weight.size(0));
  if (input.dim() == dim + 2) {
    output_size.insert(output_size.begin(), input.size(0));
  }
  return output_size;
}
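
// Illustrative example (shapes chosen for illustration, not from the code
// above): for dim == 2 with a batched input of shape (N, C, H, W) and a
// weight of shape (out_channels, C, kH, kW), the result is
// {N, out_channels, out_H, out_W}; for an unbatched (C, H, W) input the
// leading N is omitted.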
/*
  slow_conv_dilated_shape_check - check user input to the dilated
  convolution forward and backward functions.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  /*
    When the following tensors are defined:

    bias, grad_weight, grad_output

    then they are assumed to be contiguous without checking, because
    these tensors are made contiguous by calling the .contiguous()
    method or by resizing zero-sized tensors in the forward/backward
    functions.

    When grad_weight is defined, it is assumed without checking to have
    the same shape as weight; see the backward functions.
   */
  // Check size arguments
  TORCH_CHECK(
      kernel_size.size() == dim,
      "kernel sizes length should be ",
      dim,
      ", but got ",
      kernel_size.size());
  TORCH_CHECK(
      stride_size.size() == dim,
      "strides length should be ",
      dim,
      ", but got ",
      stride_size.size());
  TORCH_CHECK(
      dilation_size.size() == dim,
      "dilations length should be ",
      dim,
      ", but got ",
      dilation_size.size());
  TORCH_CHECK(
      pad_size.size() == dim,
      "pads length should be ",
      dim,
      ", but got ",
      pad_size.size());

  TORCH_CHECK(
      all_positive(kernel_size),
      "kernel size should be greater than zero, but got ",
      kernel_size);
  TORCH_CHECK(
      all_positive(stride_size),
      "stride should be greater than zero, but got ",
      stride_size);
  TORCH_CHECK(
      all_positive(dilation_size),
      "dilation should be greater than zero, but got ",
      dilation_size);

  // check input
  TORCH_CHECK(input.defined(), "input must be defined");
  bool is_batch = input.dim() == dim + 2;
  int64_t n = (is_batch ? 2 : 1);
  int64_t ndim = n + dim;
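  // For example, with dim == 2 a batched input has shape (N, C, H, W), so
  // n == 2 and ndim == 4; an unbatched input has shape (C, H, W), so
  // n == 1 and ndim == 3.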
  if (!is_batch) {
    // input dim has to be dim + 1 if not batched
    TORCH_CHECK(
        input.dim() == dim + 1,
        "input must be ",
        dim + 1,
        "D or ",
        dim + 2,
        "D tensor but got ",
        input.dim(),
        "D tensor");
  }

  // check output sizes
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);

  TORCH_CHECK(
      all_nonnegative(output_size),
      "calculated output size ",
      output_size,
      " is too small (all sizes must be non-negative)");

  // check weight
  TORCH_CHECK(weight.defined(), "weight must be defined");
  TORCH_CHECK(
      weight.dim() == dim + 2,
      "weight must be ",
      dim + 2,
      "D tensor but got ",
      weight.dim(),
      "D tensor (dim=",
      dim,
      ")");
  TORCH_CHECK(
      weight.sizes().slice(2) == kernel_size,
      "weight[2:] shape ",
      weight.sizes().slice(2),
      " must be equal to kernel_size ",
      kernel_size);

  TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));

  // check bias when present
  if (bias.defined()) {
    TORCH_CHECK(
        bias.dim() == 1,
        "bias must be 1D tensor but got ",
        bias.dim(),
        "D tensor");
    TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
  }

  // check grad_output when present
  if (grad_output.defined()) {
    TORCH_CHECK(
        grad_output.dim() == ndim,
        "grad_output must be ",
        ndim,
        "D tensor but got ",
        grad_output.dim(),
        "D tensor");
    if (is_batch) {
      TORCH_CHECK(
          grad_output.size(0) == input.size(0),
          "grad_output.size(0)=",
          grad_output.size(0),
          " must be input.size(0)=",
          input.size(0));
    }
    TORCH_CHECK(
        grad_output.size(n - 1) == weight.size(0),
        "grad_output.size(",
        n - 1,
        ")=",
        grad_output.size(n - 1),
        " must be weight.size(0)=",
        weight.size(0));
    TORCH_CHECK(
        grad_output.sizes().slice(n) == output_size,
        "grad_output[",
        n,
        ":] shape ",
        grad_output.sizes().slice(n),
        " must be equal to output size ",
        output_size);
  }
}
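
// Minimal usage sketch (hypothetical tensors; dim == 2). The forward pass
// passes an undefined grad_output:
//
//   at::native::internal::slow_conv_dilated_shape_check<2>(
//       input,        // (N, C, H, W) or (C, H, W)
//       weight,       // (out_channels, C, kH, kW)
//       bias,         // (out_channels), or undefined
//       at::Tensor(), // grad_output, undefined in the forward pass
//       /*kernel_size=*/{3, 3},
//       /*stride_size=*/{1, 1},
//       /*pad_size=*/{1, 1},
//       /*dilation_size=*/{2, 2});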

} // namespace at::native::internal