1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/cpu/vec/vec.h>
5 #include <ATen/native/Unfold2d.h>
6 #include <ATen/native/cpu/Loops.h>
7 #include <c10/util/irange.h>
8 #include <ATen/native/cpu/utils.h>
9 #include <cmath>
10
11 namespace at::native {
12
13 namespace {
14
15 template <typename scalar_t>
cadd(scalar_t * z,const scalar_t * x,const scalar_t * y,int64_t n)16 static inline void cadd(
17 scalar_t* z,
18 const scalar_t* x,
19 const scalar_t* y,
20 int64_t n) {
21 using Vec = vec::Vectorized<scalar_t>;
22 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
23 char* ptrs[] = {reinterpret_cast<char*>(z),
24 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
25 reinterpret_cast<char*>(const_cast<scalar_t*>(x)),
26 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
27 reinterpret_cast<char*>(const_cast<scalar_t*>(y))};
28 vectorized_loop(
29 ptrs,
30 n,
31 -1,
32 [](scalar_t x, scalar_t y) -> scalar_t { return x + y; },
33 [](Vec x, Vec y) -> Vec { return x + y; });
34 }
35
36 template <typename scalar_t>
unfolded2d_acc(scalar_t * finput_data,scalar_t * input_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width)37 static void unfolded2d_acc(
38 scalar_t* finput_data,
39 scalar_t* input_data,
40 int64_t kH,
41 int64_t kW,
42 int64_t dH,
43 int64_t dW,
44 int64_t padH,
45 int64_t padW,
46 int64_t n_input_plane,
47 int64_t input_height,
48 int64_t input_width,
49 int64_t output_height,
50 int64_t output_width) {
51 at::parallel_for(0, n_input_plane, 0, [&](int64_t start, int64_t end) {
52 for (const auto nip : c10::irange(start, end)) {
53 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
54 int64_t kw, kh, y, x;
55 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
56 int64_t ix, iy;
57 for (kh = 0; kh < kH; kh++) {
58 for (kw = 0; kw < kW; kw++) {
59 scalar_t* src = finput_data +
60 nip * ((size_t)kH * kW * output_height * output_width) +
61 kh * ((size_t)kW * output_height * output_width) +
62 kw * ((size_t)output_height * output_width);
63 scalar_t* dst =
64 input_data + nip * ((size_t)input_height * input_width);
65 if (padW > 0 || padH > 0) {
66 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
67 int64_t lpad, rpad;
68 for (y = 0; y < output_height; y++) {
69 iy = (int64_t)y * dH - padH + kh;
70 if (iy < 0 || iy >= input_height) {
71 } else {
72 if (dW == 1) {
73 ix = 0 - padW + kw;
74 lpad = std::max<int64_t>(0, padW - kw);
75 rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
76 scalar_t* dst_slice =
77 dst + (size_t)iy * input_width + ix + lpad;
78 cadd(
79 dst_slice,
80 dst_slice,
81 src + (size_t)y * output_width + lpad,
82 output_width - lpad - rpad);
83 } else {
84 for (x = 0; x < output_width; x++) {
85 ix = (int64_t)x * dW - padW + kw;
86 if (ix < 0 || ix >= input_width) {
87 } else {
88 scalar_t* dst_slice = dst + (size_t)iy * input_width + ix;
89 *dst_slice = *dst_slice + src[(size_t)y * output_width + x];
90 }
91 }
92 }
93 }
94 }
95 } else {
96 for (y = 0; y < output_height; y++) {
97 iy = (int64_t)y * dH + kh;
98 ix = 0 + kw;
99 if (dW == 1) {
100 scalar_t* dst_slice = dst + (size_t)iy * input_width + ix;
101 cadd(
102 dst_slice,
103 dst_slice,
104 src + (size_t)y * output_width,
105 output_width);
106 } else {
107 for (x = 0; x < output_width; x++) {
108 scalar_t* dst_slice =
109 dst + (size_t)iy * input_width + ix + x * dW;
110 *dst_slice = *dst_slice + src[(size_t)y * output_width + x];
111 }
112 }
113 }
114 }
115 }
116 }
117 }
118 });
119 }
120
template <typename scalar_t>
static void unfolded2d_acc_channels_last(
    scalar_t* finput_data,
    scalar_t* input_data,
    int64_t kH,
    int64_t kW,
    int64_t dH,
    int64_t dW,
    int64_t padH,
    int64_t padW,
    int64_t n_input_plane,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width) {
  // Channels-last scatter-add of the unfolded buffer back into input_data.
  //
  // From the indexing below, finput_data is laid out as
  // (y, x, kh, kw, channel) and input_data as (iy, ix, channel). Each
  // (kh, kw) tap therefore adds one contiguous run of n_input_plane values.
  // The loops run serially: different (y, x, kh, kw) can map to the same
  // (iy, ix), so the writes may overlap.
  const int64_t column_len = kH * kW * n_input_plane;
  const bool padded = padW > 0 || padH > 0;

  for (int64_t y = 0; y < output_height; y++) {
    for (int64_t x = 0; x < output_width; x++) {
      scalar_t* pixel_column =
          finput_data + y * output_width * column_len + x * column_len;
      for (int64_t kh = 0; kh < kH; kh++) {
        for (int64_t kw = 0; kw < kW; kw++) {
          int64_t iy = y * dH + kh;
          int64_t ix = x * dW + kw;
          if (padded) {
            iy -= padH;
            ix -= padW;
            const bool outside =
                iy < 0 || iy >= input_height || ix < 0 || ix >= input_width;
            if (outside) {
              continue;
            }
          }
          scalar_t* in_pixel =
              input_data + iy * input_width * n_input_plane + ix * n_input_plane;
          scalar_t* tap =
              pixel_column + kh * kW * n_input_plane + kw * n_input_plane;
          cadd(in_pixel, in_pixel, tap, n_input_plane);
        }
      }
    }
  }
}
175
176 /* note: due to write issues, this one cannot be parallelized as well as
177 * unfolded2d_copy */
unfolded2d_acc_kernel(ScalarType dtype,void * finput_data,void * input_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width,bool is_channels_last)178 void unfolded2d_acc_kernel(
179 ScalarType dtype,
180 void *finput_data,
181 void *input_data,
182 int64_t kH,
183 int64_t kW,
184 int64_t dH,
185 int64_t dW,
186 int64_t padH,
187 int64_t padW,
188 int64_t n_input_plane,
189 int64_t input_height,
190 int64_t input_width,
191 int64_t output_height,
192 int64_t output_width,
193 bool is_channels_last) {
194 // This function assumes that
195 // output_height*dH does not overflow a int64_t
196 // output_width*dW does not overflow a int64_t
197
198 if (is_channels_last) {
199 AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_acc_channels_last", [&] {
200 unfolded2d_acc_channels_last(
201 static_cast<scalar_t*>(finput_data),
202 static_cast<scalar_t*>(input_data),
203 kH, kW,
204 dH, dW,
205 padH, padW,
206 n_input_plane,
207 input_height,
208 input_width,
209 output_height,
210 output_width);
211 });
212 } else {
213 AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_acc", [&] {
214 unfolded2d_acc(
215 static_cast<scalar_t*>(finput_data),
216 static_cast<scalar_t*>(input_data),
217 kH, kW,
218 dH, dW,
219 padH, padW,
220 n_input_plane,
221 input_height,
222 input_width,
223 output_height,
224 output_width);
225 });
226 }
227 }
228
229 template <typename scalar_t>
unfolded2d_copy(const scalar_t * input_data,scalar_t * finput_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width)230 static void unfolded2d_copy(
231 const scalar_t* input_data,
232 scalar_t* finput_data,
233 int64_t kH,
234 int64_t kW,
235 int64_t dH,
236 int64_t dW,
237 int64_t padH,
238 int64_t padW,
239 int64_t n_input_plane,
240 int64_t input_height,
241 int64_t input_width,
242 int64_t output_height,
243 int64_t output_width) {
244 at::parallel_for(
245 0, (int64_t)n_input_plane * kH * kW, 0, [&](int64_t start, int64_t end) {
246 for (const auto k : c10::irange(start, end)) {
247 int64_t nip = k / (kH * kW);
248 int64_t rest = k % (kH * kW);
249 int64_t kh = rest / kW;
250 int64_t kw = rest % kW;
251 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
252 int64_t x, y;
253 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
254 int64_t ix, iy;
255 scalar_t* dst = finput_data +
256 nip * ((size_t)kH * kW * output_height * output_width) +
257 kh * ((size_t)kW * output_height * output_width) +
258 kw * ((size_t)output_height * output_width);
259 const scalar_t* src =
260 input_data + nip * ((size_t)input_height * input_width);
261 if (padW > 0 || padH > 0) {
262 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
263 int64_t lpad, rpad;
264 for (y = 0; y < output_height; y++) {
265 iy = (int64_t)y * dH - padH + kh;
266 if (iy < 0 || iy >= input_height) {
267 memset(
268 dst + (size_t)y * output_width,
269 0,
270 sizeof(scalar_t) * output_width);
271 } else {
272 if (dW == 1) {
273 ix = 0 - padW + kw;
274 lpad = std::max<int64_t>(0, padW - kw);
275 rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
276 if (output_width - rpad - lpad <= 0) {
277 memset(
278 dst + (size_t)y * output_width,
279 0,
280 sizeof(scalar_t) * output_width);
281 } else {
282 if (lpad > 0)
283 memset(
284 dst + (size_t)y * output_width,
285 0,
286 sizeof(scalar_t) * lpad);
287 memcpy(
288 dst + (size_t)y * output_width + lpad,
289 src + (size_t)iy * input_width + ix + lpad,
290 sizeof(scalar_t) * (output_width - rpad - lpad));
291 if (rpad > 0)
292 memset(
293 dst + (size_t)y * output_width + output_width - rpad,
294 0,
295 sizeof(scalar_t) * rpad);
296 }
297 } else {
298 for (x = 0; x < output_width; x++) {
299 ix = (int64_t)x * dW - padW + kw;
300 if (ix < 0 || ix >= input_width)
301 memset(
302 dst + (size_t)y * output_width + x,
303 0,
304 sizeof(scalar_t) * 1);
305 else
306 memcpy(
307 dst + (size_t)y * output_width + x,
308 src + (size_t)iy * input_width + ix,
309 sizeof(scalar_t) * (1));
310 }
311 }
312 }
313 }
314 } else {
315 for (y = 0; y < output_height; y++) {
316 iy = (int64_t)y * dH + kh;
317 ix = 0 + kw;
318 if (dW == 1)
319 memcpy(
320 dst + (size_t)y * output_width,
321 src + (size_t)iy * input_width + ix,
322 sizeof(scalar_t) * output_width);
323 else {
324 for (x = 0; x < output_width; x++)
325 memcpy(
326 dst + (size_t)y * output_width + x,
327 src + (size_t)iy * input_width + ix + (int64_t)x * dW,
328 sizeof(scalar_t) * (1));
329 }
330 }
331 }
332 }
333 });
334 }
335
336 template <typename scalar_t>
unfolded2d_copy_channels_last(const scalar_t * input_data,scalar_t * finput_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width)337 static void unfolded2d_copy_channels_last(
338 const scalar_t* input_data,
339 scalar_t* finput_data,
340 int64_t kH,
341 int64_t kW,
342 int64_t dH,
343 int64_t dW,
344 int64_t padH,
345 int64_t padW,
346 int64_t n_input_plane,
347 int64_t input_height,
348 int64_t input_width,
349 int64_t output_height,
350 int64_t output_width) {
351 at::parallel_for(0, output_height * output_width, 0, [&](int64_t start, int64_t end) {
352 int64_t y = 0;
353 int64_t x = 0;
354 data_index_init(start, y, output_height, x, output_width);
355
356 for (const auto k C10_UNUSED: c10::irange(start, end)) {
357 scalar_t* dst = finput_data + y * output_width * kH * kW * n_input_plane + x * kH * kW * n_input_plane;
358 const scalar_t* src = input_data;
359
360 if (padW > 0 || padH > 0) {
361 for (int64_t kh = 0; kh < kH; kh++) {
362 for (int64_t kw = 0; kw < kW; kw++) {
363 int64_t iy = y * dH - padH + kh;
364 int64_t ix = x * dW - padW + kw;
365 if (iy < 0 || iy >= input_height || ix < 0 || ix >= input_width) {
366 memset(dst + kh * kW * n_input_plane + kw * n_input_plane,
367 0,
368 sizeof(scalar_t) * n_input_plane);
369 } else {
370 memcpy(dst + kh * kW * n_input_plane + kw * n_input_plane,
371 src + iy * input_width * n_input_plane + ix * n_input_plane,
372 sizeof(scalar_t) * n_input_plane);
373 }
374 }
375 }
376 } else {
377 for (int64_t kh = 0; kh < kH; kh++) {
378 for (int64_t kw = 0; kw < kW; kw++) {
379 int64_t iy = y * dH + kh;
380 int64_t ix = x * dW + kw;
381 memcpy(dst + kh * kW * n_input_plane + kw * n_input_plane,
382 src + iy * input_width * n_input_plane + ix * n_input_plane,
383 sizeof(scalar_t) * n_input_plane);
384 }
385 }
386 }
387 // move on to next output index
388 data_index_step(y, output_height, x, output_width);
389 }
390 });
391 }
392
unfolded2d_copy_kernel(ScalarType dtype,void * finput_data,const void * input_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width,bool is_channels_last)393 void unfolded2d_copy_kernel(
394 ScalarType dtype,
395 void *finput_data,
396 const void *input_data,
397 int64_t kH,
398 int64_t kW,
399 int64_t dH,
400 int64_t dW,
401 int64_t padH,
402 int64_t padW,
403 int64_t n_input_plane,
404 int64_t input_height,
405 int64_t input_width,
406 int64_t output_height,
407 int64_t output_width,
408 bool is_channels_last) {
409 // This function assumes that
410 // kH*kW does not overflow an int
411 // n_input_plane*kH*kW does not overflow a int64_t
412 // output_height*dH does not overflow a int64_t
413 // output_width*dW does not overflow a int64_t
414
415 if (is_channels_last) {
416 AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy_channels_last", [&] {
417 unfolded2d_copy_channels_last(
418 static_cast<const scalar_t*>(input_data),
419 static_cast<scalar_t*>(finput_data),
420 kH, kW,
421 dH, dW,
422 padH, padW,
423 n_input_plane,
424 input_height,
425 input_width,
426 output_height,
427 output_width);
428 });
429 } else {
430 AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy", [&] {
431 unfolded2d_copy(
432 static_cast<const scalar_t*>(input_data),
433 static_cast<scalar_t*>(finput_data),
434 kH, kW,
435 dH, dW,
436 padH, padW,
437 n_input_plane,
438 input_height,
439 input_width,
440 output_height,
441 output_width);
442 });
443 }
444 }
445
446 } // namespace
447
// Bind the CPU kernels defined above to their dispatch stubs. The stubs
// are presumably declared in ATen/native/Unfold2d.h (included above); TODO
// confirm.
REGISTER_DISPATCH(unfolded2d_copy_stub, &unfolded2d_copy_kernel);
REGISTER_DISPATCH(unfolded2d_acc_stub, &unfolded2d_acc_kernel);
450
451 } // namespace at::native
452