#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Unfold3d.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#include <cstring> // std::memcpy, std::memset

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif // AT_MKL_ENABLED()

namespace at::native {

namespace {

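// Returns true iff 0 <= a < b using a single unsigned comparison: a negative
// `a` wraps to a value >= 2^63 when cast to uint64_t, which cannot be less
// than `b` (the callers only pass nonnegative sizes for `b`).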
bool IsAGeZeroAndALtB(int64_t a, int64_t b) {
  return static_cast<uint64_t>(a) < static_cast<uint64_t>(b);
}

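// Copies an M x N matrix from A (row stride lda) to B (row stride ldb),
// one contiguous row at a time.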
template <typename T>
void MatCopy(int64_t M, int64_t N, int64_t lda, int64_t ldb, const T* A, T* B) {
  for (const auto i : c10::irange(M)) {
    std::memcpy(B + i * ldb, A + i * lda, N * sizeof(T));
  }
}

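// Strided variant: consecutive elements within a row of A are stridea apart,
// and those within a row of B are strideb apart.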
template <typename T>
void MatCopy(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const T* A,
    T* B) {
  for (const auto i : c10::irange(M)) {
    const T* A_ptr = A + i * lda;
    T* B_ptr = B + i * ldb;
    for (const auto j : c10::irange(N)) {
      B_ptr[j * strideb] = A_ptr[j * stridea];
    }
  }
}

// Y += X
template <typename T>
void MatAdd(int64_t M, int64_t N, int64_t ldx, int64_t ldy, const T* X, T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j] += X[i * ldx + j];
    }
  }
}

// Y += X (strided variant)
template <typename T>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const T* X,
    T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j * stridey] += X[i * ldx + j * stridex];
    }
  }
}

#if AT_MKL_ENABLED()

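// When MKL is available, specialize the float/double cases to MKL's
// out-of-place copy/add routines (mkl_?omatcopy, mkl_?omatadd, cblas_?axpy),
// which are typically faster than the generic loops above.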
template <>
void MatCopy<float>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t ldb,
    const float* A,
    float* B) {
  mkl_somatcopy('R', 'N', M, N, 1.0f, A, lda, B, ldb);
}

template <>
void MatCopy<double>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t ldb,
    const double* A,
    double* B) {
  mkl_domatcopy('R', 'N', M, N, 1.0, A, lda, B, ldb);
}

template <>
void MatCopy<float>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const float* A,
    float* B) {
  mkl_somatcopy2('R', 'N', M, N, 1.0f, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatCopy<double>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const double* A,
    double* B) {
  mkl_domatcopy2('R', 'N', M, N, 1.0, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatAdd<float>(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t ldy,
    const float* X,
    float* Y) {
  mkl_somatadd('R', 'N', 'N', M, N, 1.0f, X, ldx, 1.0f, Y, ldy, Y, ldy);
}

template <>
void MatAdd<double>(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t ldy,
    const double* X,
    double* Y) {
  mkl_domatadd('R', 'N', 'N', M, N, 1.0, X, ldx, 1.0, Y, ldy, Y, ldy);
}

template <>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const float* X,
    float* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_saxpy(N, 1.0f, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

template <>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const double* X,
    double* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_daxpy(N, 1.0, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

#endif // AT_MKL_ENABLED()

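// im2col-style unfold for the zero-padding fast path. Each output row p
// indexes one (channel, kd, kh, kw) combination; for every output depth
// slice, the inner H x W gather reduces to a MatCopy, which hits the MKL
// specializations above for float/double when enabled.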
template <typename T>
void Unfold3dZeroPaddingCopyKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d + kd;
        const T* src_ptr = src + c * X_size + xd * X_H * X_W + kh * X_W + kw;
        T* dst_ptr = dst + p * Y_size + yd * Y_H * Y_W;
        if (stride_w == 1) {
          MatCopy<T>(Y_H, Y_W, stride_h * X_W, Y_W, src_ptr, dst_ptr);
        } else {
          MatCopy<T>(
              Y_H, Y_W, stride_h * X_W, stride_w, Y_W, 1, src_ptr, dst_ptr);
        }
      }
    }
  });
}

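// General unfold (im2col) with padding. Rows of dst that fall entirely
// outside the input are zero-filled with memset; in-bounds elements are
// gathered with an explicit bounds check per output coordinate.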
template <typename T>
void Unfold3dCopyKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingCopyKernelImpl<T>(
        C,
        X_D,
        X_H,
        X_W,
        Y_D,
        Y_H,
        Y_W,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        src,
        dst);
    return;
  }

  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      const T* src_ptr = src + c * X_size;
      T* dst_ptr = dst + p * Y_size;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d - pad_d + kd;
        if (!IsAGeZeroAndALtB(xd, X_D)) {
          std::memset(dst_ptr + yd * Y_H * Y_W, 0, Y_H * Y_W * sizeof(T));
          continue;
        }
        for (const auto yh : c10::irange(Y_H)) {
          const int64_t xh = yh * stride_h - pad_h + kh;
          if (!IsAGeZeroAndALtB(xh, X_H)) {
            std::memset(
                dst_ptr + yd * Y_H * Y_W + yh * Y_W, 0, Y_W * sizeof(T));
            continue;
          }
          for (const auto yw : c10::irange(Y_W)) {
            const int64_t xw = yw * stride_w - pad_w + kw;
            dst_ptr[yd * Y_H * Y_W + yh * Y_W + yw] = IsAGeZeroAndALtB(xw, X_W)
                ? src_ptr[xd * X_H * X_W + xh * X_W + xw]
                : T(0);
          }
        }
      }
    }
  });
}

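// col2im-style fold for the zero-padding fast path: zeroes dst and then
// accumulates each unfolded patch row back into the input layout. Note the
// parallelism is over channels C (not over n as in the copy kernels), so
// each thread owns a disjoint slice of dst and writes never race.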
template <typename T>
void Unfold3dZeroPaddingAccKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p =
                c * kernel_size + kd * kernel_h * kernel_w + kh * kernel_w + kw;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d + kd;
              const T* src_ptr = src + p * Y_size + yd * Y_H * Y_W;
              T* dst_ptr = dst + c * X_size + xd * X_H * X_W + kh * X_W + kw;
              if (stride_w == 1) {
                MatAdd<T>(Y_H, Y_W, Y_W, stride_h * X_W, src_ptr, dst_ptr);
              } else {
                MatAdd<T>(
                    Y_H,
                    Y_W,
                    Y_W,
                    1,
                    stride_h * X_W,
                    stride_w,
                    src_ptr,
                    dst_ptr);
              }
            }
          }
        }
      }
    }
  });
}

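// General fold (col2im) with padding: accumulates src back into dst,
// skipping output coordinates that map outside the input.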
template <typename T>
void Unfold3dAccKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingAccKernelImpl<T>(
        C,
        X_D,
        X_H,
        X_W,
        Y_D,
        Y_H,
        Y_W,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        src,
        dst);
    return;
  }
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      T* dst_ptr = dst + c * X_size;
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p =
                c * kernel_size + kd * kernel_h * kernel_w + kh * kernel_w + kw;
            const T* src_ptr = src + p * Y_size;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d - pad_d + kd;
              if (!IsAGeZeroAndALtB(xd, X_D)) {
                continue;
              }
              for (const auto yh : c10::irange(Y_H)) {
                const int64_t xh = yh * stride_h - pad_h + kh;
                if (!IsAGeZeroAndALtB(xh, X_H)) {
                  continue;
                }
                for (const auto yw : c10::irange(Y_W)) {
                  const int64_t xw = yw * stride_w - pad_w + kw;
                  if (IsAGeZeroAndALtB(xw, X_W)) {
                    dst_ptr[xd * X_H * X_W + xh * X_W + xw] +=
                        src_ptr[yd * Y_H * Y_W + yh * Y_W + yw];
                  }
                }
              }
            }
          }
        }
      }
    }
  });
}

} // namespace

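// Unfolds src, viewed as a (C, X_D, X_H, X_W) tensor, into dst, viewed as a
// (C * kernel_d * kernel_h * kernel_w, Y_D * Y_H * Y_W) column buffer,
// dispatching on dtype. A minimal usage sketch (pointer and size names here
// are illustrative, not from this file), in the spirit of the slow conv3d
// path, which lowers 3D convolution to a matrix multiply over this buffer:
//
//   Unfold3dCopyCPU(
//       input.scalar_type(), input_ptr, C, D, H, W, out_D, out_H, out_W,
//       kD, kH, kW, sD, sH, sW, pD, pH, pW, columns_ptr);
//   // then roughly: output = weight.view({out_C, -1}).mm(columns)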
void Unfold3dCopyCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dCopyCPU",
      [=, &src]() {
        Unfold3dCopyKernelImpl<scalar_t>(
            C,
            X_D,
            X_H,
            X_W,
            Y_D,
            Y_H,
            Y_W,
            kernel_d,
            kernel_h,
            kernel_w,
            stride_d,
            stride_h,
            stride_w,
            pad_d,
            pad_h,
            pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

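// Inverse of Unfold3dCopyCPU: accumulates the column buffer back into a
// (C, X_D, X_H, X_W) tensor (col2im), e.g. to form grad_input in the
// backward pass.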
void Unfold3dAccCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dAccCPU",
      [=, &src]() {
        Unfold3dAccKernelImpl<scalar_t>(
            C,
            X_D,
            X_H,
            X_W,
            Y_D,
            Y_H,
            Y_W,
            kernel_d,
            kernel_h,
            kernel_w,
            stride_d,
            stride_h,
            stride_w,
            pad_d,
            pad_h,
            pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

} // namespace at::native