xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/Unfold2d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/cpu/vec/vec.h>
5 #include <ATen/native/Unfold2d.h>
6 #include <ATen/native/cpu/Loops.h>
7 #include <c10/util/irange.h>
8 #include <ATen/native/cpu/utils.h>
9 #include <cmath>
10 
11 namespace at::native {
12 
13 namespace {
14 
15 template <typename scalar_t>
cadd(scalar_t * z,const scalar_t * x,const scalar_t * y,int64_t n)16 static inline void cadd(
17     scalar_t* z,
18     const scalar_t* x,
19     const scalar_t* y,
20     int64_t n) {
21   using Vec = vec::Vectorized<scalar_t>;
22   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
23   char* ptrs[] = {reinterpret_cast<char*>(z),
24                   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
25                   reinterpret_cast<char*>(const_cast<scalar_t*>(x)),
26                   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
27                   reinterpret_cast<char*>(const_cast<scalar_t*>(y))};
28   vectorized_loop(
29       ptrs,
30       n,
31       -1,
32       [](scalar_t x, scalar_t y) -> scalar_t { return x + y; },
33       [](Vec x, Vec y) -> Vec { return x + y; });
34 }
35 
36 template <typename scalar_t>
unfolded2d_acc(scalar_t * finput_data,scalar_t * input_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width)37 static void unfolded2d_acc(
38     scalar_t* finput_data,
39     scalar_t* input_data,
40     int64_t kH,
41     int64_t kW,
42     int64_t dH,
43     int64_t dW,
44     int64_t padH,
45     int64_t padW,
46     int64_t n_input_plane,
47     int64_t input_height,
48     int64_t input_width,
49     int64_t output_height,
50     int64_t output_width) {
51   at::parallel_for(0, n_input_plane, 0, [&](int64_t start, int64_t end) {
52     for (const auto nip : c10::irange(start, end)) {
53       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
54       int64_t kw, kh, y, x;
55       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
56       int64_t ix, iy;
57       for (kh = 0; kh < kH; kh++) {
58         for (kw = 0; kw < kW; kw++) {
59           scalar_t* src = finput_data +
60               nip * ((size_t)kH * kW * output_height * output_width) +
61               kh * ((size_t)kW * output_height * output_width) +
62               kw * ((size_t)output_height * output_width);
63           scalar_t* dst =
64               input_data + nip * ((size_t)input_height * input_width);
65           if (padW > 0 || padH > 0) {
66             // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
67             int64_t lpad, rpad;
68             for (y = 0; y < output_height; y++) {
69               iy = (int64_t)y * dH - padH + kh;
70               if (iy < 0 || iy >= input_height) {
71               } else {
72                 if (dW == 1) {
73                   ix = 0 - padW + kw;
74                   lpad = std::max<int64_t>(0, padW - kw);
75                   rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
76                   scalar_t* dst_slice =
77                       dst + (size_t)iy * input_width + ix + lpad;
78                   cadd(
79                       dst_slice,
80                       dst_slice,
81                       src + (size_t)y * output_width + lpad,
82                       output_width - lpad - rpad);
83                 } else {
84                   for (x = 0; x < output_width; x++) {
85                     ix = (int64_t)x * dW - padW + kw;
86                     if (ix < 0 || ix >= input_width) {
87                     } else {
88                       scalar_t* dst_slice = dst + (size_t)iy * input_width + ix;
89                       *dst_slice = *dst_slice + src[(size_t)y * output_width + x];
90                     }
91                   }
92                 }
93               }
94             }
95           } else {
96             for (y = 0; y < output_height; y++) {
97               iy = (int64_t)y * dH + kh;
98               ix = 0 + kw;
99               if (dW == 1) {
100                 scalar_t* dst_slice = dst + (size_t)iy * input_width + ix;
101                 cadd(
102                     dst_slice,
103                     dst_slice,
104                     src + (size_t)y * output_width,
105                     output_width);
106               } else {
107                 for (x = 0; x < output_width; x++) {
108                   scalar_t* dst_slice =
109                       dst + (size_t)iy * input_width + ix + x * dW;
110                   *dst_slice = *dst_slice + src[(size_t)y * output_width + x];
111                 }
112               }
113             }
114           }
115         }
116       }
117     }
118   });
119 }
120 
// Scatter-adds the channels-last unfolded buffer back into the input.
//
// Layout, as implied by the offset arithmetic below (both contiguous):
//   finput_data: [output_height][output_width][kH][kW][n_input_plane]
//   input_data:  [input_height][input_width][n_input_plane]
// Each kernel tap adds a run of n_input_plane contiguous values into the
// matching input pixel. When any padding is configured, taps that land
// outside the input are skipped. Without padding, no bounds check is made.
//
// Runs serially because windows at different (y, x) positions can map onto
// the same input pixel.
template <typename scalar_t>
static void unfolded2d_acc_channels_last(
    scalar_t* finput_data,
    scalar_t* input_data,
    int64_t kH,
    int64_t kW,
    int64_t dH,
    int64_t dW,
    int64_t padH,
    int64_t padW,
    int64_t n_input_plane,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width) {
  const bool has_padding = padW > 0 || padH > 0;
  // Padding offsets are only applied when padding is enabled at all.
  const int64_t pad_top = has_padding ? padH : 0;
  const int64_t pad_left = has_padding ? padW : 0;
  const int64_t column_size = kH * kW * n_input_plane;
  const int64_t input_row_stride = input_width * n_input_plane;

  for (int64_t y = 0; y < output_height; y++) {
    for (int64_t x = 0; x < output_width; x++) {
      const scalar_t* const column =
          finput_data + (y * output_width + x) * column_size;
      for (int64_t kh = 0; kh < kH; kh++) {
        const int64_t iy = y * dH - pad_top + kh;
        for (int64_t kw = 0; kw < kW; kw++) {
          const int64_t ix = x * dW - pad_left + kw;
          const bool inside =
              iy >= 0 && iy < input_height && ix >= 0 && ix < input_width;
          if (has_padding && !inside) {
            continue;
          }
          scalar_t* const pixel =
              input_data + iy * input_row_stride + ix * n_input_plane;
          const scalar_t* const tap = column + (kh * kW + kw) * n_input_plane;
          cadd(pixel, pixel, tap, n_input_plane);
        }
      }
    }
  }
}
175 
176 /* note: due to write issues, this one cannot be parallelized as well as
177  * unfolded2d_copy */
unfolded2d_acc_kernel(ScalarType dtype,void * finput_data,void * input_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width,bool is_channels_last)178 void unfolded2d_acc_kernel(
179     ScalarType dtype,
180     void *finput_data,
181     void *input_data,
182     int64_t kH,
183     int64_t kW,
184     int64_t dH,
185     int64_t dW,
186     int64_t padH,
187     int64_t padW,
188     int64_t n_input_plane,
189     int64_t input_height,
190     int64_t input_width,
191     int64_t output_height,
192     int64_t output_width,
193     bool is_channels_last) {
194   // This function assumes that
195   // output_height*dH does not overflow a int64_t
196   // output_width*dW does not overflow a int64_t
197 
198   if (is_channels_last) {
199     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_acc_channels_last", [&] {
200       unfolded2d_acc_channels_last(
201           static_cast<scalar_t*>(finput_data),
202           static_cast<scalar_t*>(input_data),
203           kH, kW,
204           dH, dW,
205           padH, padW,
206           n_input_plane,
207           input_height,
208           input_width,
209           output_height,
210           output_width);
211      });
212   } else {
213     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_acc", [&] {
214       unfolded2d_acc(
215           static_cast<scalar_t*>(finput_data),
216           static_cast<scalar_t*>(input_data),
217           kH, kW,
218           dH, dW,
219           padH, padW,
220           n_input_plane,
221           input_height,
222           input_width,
223           output_height,
224           output_width);
225       });
226   }
227 }
228 
229 template <typename scalar_t>
unfolded2d_copy(const scalar_t * input_data,scalar_t * finput_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width)230 static void unfolded2d_copy(
231     const scalar_t* input_data,
232     scalar_t* finput_data,
233     int64_t kH,
234     int64_t kW,
235     int64_t dH,
236     int64_t dW,
237     int64_t padH,
238     int64_t padW,
239     int64_t n_input_plane,
240     int64_t input_height,
241     int64_t input_width,
242     int64_t output_height,
243     int64_t output_width) {
244   at::parallel_for(
245       0, (int64_t)n_input_plane * kH * kW, 0, [&](int64_t start, int64_t end) {
246         for (const auto k : c10::irange(start, end)) {
247           int64_t nip = k / (kH * kW);
248           int64_t rest = k % (kH * kW);
249           int64_t kh = rest / kW;
250           int64_t kw = rest % kW;
251           // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
252           int64_t x, y;
253           // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
254           int64_t ix, iy;
255           scalar_t* dst = finput_data +
256               nip * ((size_t)kH * kW * output_height * output_width) +
257               kh * ((size_t)kW * output_height * output_width) +
258               kw * ((size_t)output_height * output_width);
259           const scalar_t* src =
260               input_data + nip * ((size_t)input_height * input_width);
261           if (padW > 0 || padH > 0) {
262             // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
263             int64_t lpad, rpad;
264             for (y = 0; y < output_height; y++) {
265               iy = (int64_t)y * dH - padH + kh;
266               if (iy < 0 || iy >= input_height) {
267                 memset(
268                     dst + (size_t)y * output_width,
269                     0,
270                     sizeof(scalar_t) * output_width);
271               } else {
272                 if (dW == 1) {
273                   ix = 0 - padW + kw;
274                   lpad = std::max<int64_t>(0, padW - kw);
275                   rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
276                   if (output_width - rpad - lpad <= 0) {
277                     memset(
278                         dst + (size_t)y * output_width,
279                         0,
280                         sizeof(scalar_t) * output_width);
281                   } else {
282                     if (lpad > 0)
283                       memset(
284                           dst + (size_t)y * output_width,
285                           0,
286                           sizeof(scalar_t) * lpad);
287                     memcpy(
288                         dst + (size_t)y * output_width + lpad,
289                         src + (size_t)iy * input_width + ix + lpad,
290                         sizeof(scalar_t) * (output_width - rpad - lpad));
291                     if (rpad > 0)
292                       memset(
293                           dst + (size_t)y * output_width + output_width - rpad,
294                           0,
295                           sizeof(scalar_t) * rpad);
296                   }
297                 } else {
298                   for (x = 0; x < output_width; x++) {
299                     ix = (int64_t)x * dW - padW + kw;
300                     if (ix < 0 || ix >= input_width)
301                       memset(
302                           dst + (size_t)y * output_width + x,
303                           0,
304                           sizeof(scalar_t) * 1);
305                     else
306                       memcpy(
307                           dst + (size_t)y * output_width + x,
308                           src + (size_t)iy * input_width + ix,
309                           sizeof(scalar_t) * (1));
310                   }
311                 }
312               }
313             }
314           } else {
315             for (y = 0; y < output_height; y++) {
316               iy = (int64_t)y * dH + kh;
317               ix = 0 + kw;
318               if (dW == 1)
319                 memcpy(
320                     dst + (size_t)y * output_width,
321                     src + (size_t)iy * input_width + ix,
322                     sizeof(scalar_t) * output_width);
323               else {
324                 for (x = 0; x < output_width; x++)
325                   memcpy(
326                       dst + (size_t)y * output_width + x,
327                       src + (size_t)iy * input_width + ix + (int64_t)x * dW,
328                       sizeof(scalar_t) * (1));
329               }
330             }
331           }
332         }
333       });
334 }
335 
336 template <typename scalar_t>
unfolded2d_copy_channels_last(const scalar_t * input_data,scalar_t * finput_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width)337 static void unfolded2d_copy_channels_last(
338     const scalar_t* input_data,
339     scalar_t* finput_data,
340     int64_t kH,
341     int64_t kW,
342     int64_t dH,
343     int64_t dW,
344     int64_t padH,
345     int64_t padW,
346     int64_t n_input_plane,
347     int64_t input_height,
348     int64_t input_width,
349     int64_t output_height,
350     int64_t output_width) {
351   at::parallel_for(0, output_height * output_width, 0, [&](int64_t start, int64_t end) {
352     int64_t y = 0;
353     int64_t x = 0;
354     data_index_init(start, y, output_height, x, output_width);
355 
356     for (const auto k C10_UNUSED: c10::irange(start, end)) {
357       scalar_t* dst = finput_data + y * output_width * kH * kW * n_input_plane + x * kH * kW * n_input_plane;
358       const scalar_t* src = input_data;
359 
360       if (padW > 0 || padH > 0) {
361         for (int64_t kh = 0; kh < kH; kh++) {
362           for (int64_t kw = 0; kw < kW; kw++) {
363             int64_t iy = y * dH - padH + kh;
364             int64_t ix = x * dW - padW + kw;
365             if (iy < 0 || iy >= input_height || ix < 0 || ix >= input_width) {
366               memset(dst + kh * kW * n_input_plane + kw * n_input_plane,
367                     0,
368                     sizeof(scalar_t) * n_input_plane);
369             } else {
370               memcpy(dst + kh * kW * n_input_plane + kw * n_input_plane,
371                      src + iy * input_width * n_input_plane + ix * n_input_plane,
372                      sizeof(scalar_t) * n_input_plane);
373             }
374           }
375         }
376       } else {
377         for (int64_t kh = 0; kh < kH; kh++) {
378           for (int64_t kw = 0; kw < kW; kw++) {
379             int64_t iy = y * dH + kh;
380             int64_t ix = x * dW + kw;
381             memcpy(dst + kh * kW * n_input_plane + kw * n_input_plane,
382                    src + iy * input_width * n_input_plane + ix * n_input_plane,
383                    sizeof(scalar_t) * n_input_plane);
384           }
385         }
386       }
387       // move on to next output index
388       data_index_step(y, output_height, x, output_width);
389     }
390   });
391 }
392 
unfolded2d_copy_kernel(ScalarType dtype,void * finput_data,const void * input_data,int64_t kH,int64_t kW,int64_t dH,int64_t dW,int64_t padH,int64_t padW,int64_t n_input_plane,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width,bool is_channels_last)393 void unfolded2d_copy_kernel(
394     ScalarType dtype,
395     void *finput_data,
396     const void *input_data,
397     int64_t kH,
398     int64_t kW,
399     int64_t dH,
400     int64_t dW,
401     int64_t padH,
402     int64_t padW,
403     int64_t n_input_plane,
404     int64_t input_height,
405     int64_t input_width,
406     int64_t output_height,
407     int64_t output_width,
408     bool is_channels_last) {
409   // This function assumes that
410   // kH*kW does not overflow an int
411   // n_input_plane*kH*kW does not overflow a int64_t
412   // output_height*dH does not overflow a int64_t
413   // output_width*dW does not overflow a int64_t
414 
415   if (is_channels_last) {
416     AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy_channels_last", [&] {
417       unfolded2d_copy_channels_last(
418           static_cast<const scalar_t*>(input_data),
419           static_cast<scalar_t*>(finput_data),
420             kH, kW,
421             dH, dW,
422             padH, padW,
423             n_input_plane,
424             input_height,
425             input_width,
426             output_height,
427             output_width);
428     });
429   } else {
430     AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy", [&] {
431       unfolded2d_copy(
432           static_cast<const scalar_t*>(input_data),
433           static_cast<scalar_t*>(finput_data),
434             kH, kW,
435             dH, dW,
436             padH, padW,
437             n_input_plane,
438             input_height,
439             input_width,
440             output_height,
441             output_width);
442     });
443   }
444 }
445 
446 } // namespace
447 
448 REGISTER_DISPATCH(unfolded2d_copy_stub, &unfolded2d_copy_kernel);
449 REGISTER_DISPATCH(unfolded2d_acc_stub, &unfolded2d_acc_kernel);
450 
451 } // namespace at::native
452