// aten/src/ATen/native/Unfold3d.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Unfold3d.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#include <cstring> // std::memcpy, std::memset

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif // AT_MKL_ENABLED()

namespace at::native {

namespace {
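// Checks 0 <= a && a < b with a single comparison: casting a negative
// int64_t to uint64_t wraps it to a huge value that always fails the < test
// (b is a nonnegative size here).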
bool IsAGeZeroAndALtB(int64_t a, int64_t b) {
  return static_cast<uint64_t>(a) < static_cast<uint64_t>(b);
}

// Copies an M x N matrix from A to B, where lda/ldb are the row strides
// (leading dimensions) of A and B; each row is contiguous.
template <typename T>
void MatCopy(int64_t M, int64_t N, int64_t lda, int64_t ldb, const T* A, T* B) {
  for (const auto i : c10::irange(M)) {
    std::memcpy(B + i * ldb, A + i * lda, N * sizeof(T));
  }
}

28 template <typename T>
MatCopy(int64_t M,int64_t N,int64_t lda,int64_t stridea,int64_t ldb,int64_t strideb,const T * A,T * B)29 void MatCopy(
30     int64_t M,
31     int64_t N,
32     int64_t lda,
33     int64_t stridea,
34     int64_t ldb,
35     int64_t strideb,
36     const T* A,
37     T* B) {
38   for (const auto i : c10::irange(M)) {
39     const T* A_ptr = A + i * lda;
40     T* B_ptr = B + i * ldb;
41     for (const auto j : c10::irange(N)) {
42       B_ptr[j * strideb] = A_ptr[j * stridea];
43     }
44   }
45 }
46 
// Y += X
template <typename T>
void MatAdd(int64_t M, int64_t N, int64_t ldx, int64_t ldy, const T* X, T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j] += X[i * ldx + j];
    }
  }
}

// Y += X, strided variant.
template <typename T>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const T* X,
    T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j * stridey] += X[i * ldx + j * stridex];
    }
  }
}

#if AT_MKL_ENABLED()

// With MKL available, route float/double copies and adds through the
// optimized mkl_?omatcopy / mkl_?omatadd / cblas_?axpy routines.

template <>
void MatCopy<float>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t ldb,
    const float* A,
    float* B) {
  mkl_somatcopy('R', 'N', M, N, 1.0f, A, lda, B, ldb);
}

template <>
void MatCopy<double>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t ldb,
    const double* A,
    double* B) {
  mkl_domatcopy('R', 'N', M, N, 1.0, A, lda, B, ldb);
}

template <>
void MatCopy<float>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const float* A,
    float* B) {
  mkl_somatcopy2('R', 'N', M, N, 1.0f, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatCopy<double>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const double* A,
    double* B) {
  mkl_domatcopy2('R', 'N', M, N, 1.0, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatAdd<float>(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t ldy,
    const float* X,
    float* Y) {
  mkl_somatadd('R', 'N', 'N', M, N, 1.0f, X, ldx, 1.0f, Y, ldy, Y, ldy);
}

template <>
void MatAdd<double>(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t ldy,
    const double* X,
    double* Y) {
  mkl_domatadd('R', 'N', 'N', M, N, 1.0, X, ldx, 1.0, Y, ldy, Y, ldy);
}

template <>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const float* X,
    float* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_saxpy(N, 1.0f, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

template <>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const double* X,
    double* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_daxpy(N, 1.0, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

#endif // AT_MKL_ENABLED()

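// Unfold (im2col-style) copy without padding. Each output row index p
// encodes a (channel, kernel offset) tuple:
//   p = ((c * kernel_d + kd) * kernel_h + kh) * kernel_w + kw,
// and the kernel performs
//   dst[p][yd][yh][yw] = src[c][yd * stride_d + kd]
//                           [yh * stride_h + kh]
//                           [yw * stride_w + kw].
// For a fixed yd this is a 2D copy, so it lowers to MatCopy: contiguous
// rows when stride_w == 1, strided rows otherwise.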
template <typename T>
void Unfold3dZeroPaddingCopyKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      // Decompose p into (c, kd, kh, kw).
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d + kd;
        const T* src_ptr = src + c * X_size + xd * X_H * X_W + kh * X_W + kw;
        T* dst_ptr = dst + p * Y_size + yd * Y_H * Y_W;
        if (stride_w == 1) {
          MatCopy<T>(Y_H, Y_W, stride_h * X_W, Y_W, src_ptr, dst_ptr);
        } else {
          MatCopy<T>(
              Y_H, Y_W, stride_h * X_W, stride_w, Y_W, 1, src_ptr, dst_ptr);
        }
      }
    }
  });
}

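// General unfold with zero padding: same mapping as above, but with
// x = y * stride - pad + k per dimension. Out-of-range reads produce zeros;
// fully out-of-range depth slabs and height rows are zero-filled with a
// single memset. Dispatches to the fast path when all pads are zero.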
template <typename T>
void Unfold3dCopyKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingCopyKernelImpl<T>(
        C,
        X_D,
        X_H,
        X_W,
        Y_D,
        Y_H,
        Y_W,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        src,
        dst);
    return;
  }

  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      const T* src_ptr = src + c * X_size;
      T* dst_ptr = dst + p * Y_size;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d - pad_d + kd;
        if (!IsAGeZeroAndALtB(xd, X_D)) {
          std::memset(dst_ptr + yd * Y_H * Y_W, 0, Y_H * Y_W * sizeof(T));
          continue;
        }
        for (const auto yh : c10::irange(Y_H)) {
          const int64_t xh = yh * stride_h - pad_h + kh;
          if (!IsAGeZeroAndALtB(xh, X_H)) {
            std::memset(
                dst_ptr + yd * Y_H * Y_W + yh * Y_W, 0, Y_W * sizeof(T));
            continue;
          }
          for (const auto yw : c10::irange(Y_W)) {
            const int64_t xw = yw * stride_w - pad_w + kw;
            dst_ptr[yd * Y_H * Y_W + yh * Y_W + yw] = IsAGeZeroAndALtB(xw, X_W)
                ? src_ptr[xd * X_H * X_W + xh * X_W + xw]
                : T(0);
          }
        }
      }
    }
  });
}

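// Fold (col2im-style) accumulation without padding: the transpose of the
// copy kernel above. Each dst[c][xd][xh][xw] accumulates every column entry
// src[p][yd][yh][yw] that was copied from it. Parallelism is over channels,
// not over p: overlapping kernel windows make different p columns target
// the same destination elements, so splitting over C keeps writers disjoint.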
template <typename T>
void Unfold3dZeroPaddingAccKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    // Zero this thread's slab of channels before accumulating into it.
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p =
                c * kernel_size + kd * kernel_h * kernel_w + kh * kernel_w + kw;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d + kd;
              const T* src_ptr = src + p * Y_size + yd * Y_H * Y_W;
              T* dst_ptr = dst + c * X_size + xd * X_H * X_W + kh * X_W + kw;
              if (stride_w == 1) {
                MatAdd<T>(Y_H, Y_W, Y_W, stride_h * X_W, src_ptr, dst_ptr);
              } else {
                MatAdd<T>(
                    Y_H,
                    Y_W,
                    Y_W,
                    1,
                    stride_h * X_W,
                    stride_w,
                    src_ptr,
                    dst_ptr);
              }
            }
          }
        }
      }
    }
  });
}

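// General fold with zero padding: column entries whose destination
// coordinate (computed as x = y * stride - pad + k) falls outside the input
// volume came from padding, so they are skipped rather than accumulated.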
template <typename T>
void Unfold3dAccKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingAccKernelImpl<T>(
        C,
        X_D,
        X_H,
        X_W,
        Y_D,
        Y_H,
        Y_W,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        src,
        dst);
    return;
  }
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      T* dst_ptr = dst + c * X_size;
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p =
                c * kernel_size + kd * kernel_h * kernel_w + kh * kernel_w + kw;
            const T* src_ptr = src + p * Y_size;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d - pad_d + kd;
              if (!IsAGeZeroAndALtB(xd, X_D)) {
                continue;
              }
              for (const auto yh : c10::irange(Y_H)) {
                const int64_t xh = yh * stride_h - pad_h + kh;
                if (!IsAGeZeroAndALtB(xh, X_H)) {
                  continue;
                }
                for (const auto yw : c10::irange(Y_W)) {
                  const int64_t xw = yw * stride_w - pad_w + kw;
                  if (IsAGeZeroAndALtB(xw, X_W)) {
                    dst_ptr[xd * X_H * X_W + xh * X_W + xw] +=
                        src_ptr[yd * Y_H * Y_W + yh * Y_W + yw];
                  }
                }
              }
            }
          }
        }
      }
    }
  });
}

} // namespace

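// Public entry points. Buffer sizing is the caller's responsibility: the
// unfolded buffer holds (C * kernel_d * kernel_h * kernel_w) *
// (Y_D * Y_H * Y_W) elements and the volume buffer C * X_D * X_H * X_W,
// with each output extent following the usual convolution size relation,
// e.g. Y_D = (X_D + 2 * pad_d - kernel_d) / stride_d + 1.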
void Unfold3dCopyCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dCopyCPU",
      [=, &src]() {
        Unfold3dCopyKernelImpl<scalar_t>(
            C,
            X_D,
            X_H,
            X_W,
            Y_D,
            Y_H,
            Y_W,
            kernel_d,
            kernel_h,
            kernel_w,
            stride_d,
            stride_h,
            stride_w,
            pad_d,
            pad_h,
            pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

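// Example (a minimal sketch, not part of this file): a caller unfolding a
// single-channel 4x4x4 float volume with a 2x2x2 kernel, stride 1, and no
// padding, giving 3x3x3 output positions, might look like:
//
//   std::vector<float> src(1 * 4 * 4 * 4);              // C * X_D * X_H * X_W
//   std::vector<float> dst((1 * 2 * 2 * 2) * (3 * 3 * 3));
//   at::native::Unfold3dCopyCPU(
//       at::kFloat, src.data(),
//       /*C=*/1, /*X_D=*/4, /*X_H=*/4, /*X_W=*/4,
//       /*Y_D=*/3, /*Y_H=*/3, /*Y_W=*/3,
//       /*kernel_d=*/2, /*kernel_h=*/2, /*kernel_w=*/2,
//       /*stride_d=*/1, /*stride_h=*/1, /*stride_w=*/1,
//       /*pad_d=*/0, /*pad_h=*/0, /*pad_w=*/0,
//       dst.data());

// Unfold3dAccCPU is the adjoint: it folds unfolded columns back into the
// volume layout with accumulation, e.g. when computing input gradients for
// the slow 3D convolution path.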
void Unfold3dAccCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dAccCPU",
      [=, &src]() {
        Unfold3dAccKernelImpl<scalar_t>(
            C,
            X_D,
            X_H,
            X_W,
            Y_D,
            Y_H,
            Y_W,
            kernel_d,
            kernel_h,
            kernel_w,
            stride_d,
            stride_h,
            stride_w,
            pad_d,
            pad_h,
            pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

} // namespace at::native