/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "addressing.hpp"
#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
#include <algorithm>  // std::min
#include <cstring>

namespace arm_conv {
namespace addressing {
void fill_pointer_array(
  size_t element_size,
  void **dest_raw, const unsigned int array_rows, const unsigned int array_cols,
  void *base_ptr_raw, size_t ld_row, size_t ld_col,
  void *pad_buffer_raw,
  const unsigned int pad_top, const unsigned int valid_rows,
  const unsigned int pad_left, const unsigned int valid_cols
)
{
  auto dest = reinterpret_cast<char **>(dest_raw);
  auto base_ptr = reinterpret_cast<char *>(base_ptr_raw);
  auto pad_buffer = reinterpret_cast<char *>(pad_buffer_raw);
  ld_row *= element_size;
  ld_col *= element_size;

  const auto last_valid_row = std::min(pad_top + valid_rows, array_rows);
  const auto last_valid_col = std::min(pad_left + valid_cols, array_cols);

  // Rows above the valid region: fill with pointers to the padding buffer
  unsigned int i = 0;
  for (; i < pad_top; i++)
  {
    for (unsigned int j = 0; j < array_cols; j++)
    {
      *(dest++) = pad_buffer;
    }
  }
  // Rows within the valid region: point at the source tensor where possible
  for (; i < last_valid_row; i++)
  {
    unsigned int j = 0;
    auto colptr = base_ptr;
    base_ptr += ld_row;

    for (; j < pad_left; j++)
    {
      *(dest++) = pad_buffer;
    }
    for (; j < last_valid_col; j++)
    {
      *(dest++) = colptr;
      colptr += ld_col;
    }
    for (; j < array_cols; j++)
    {
      *(dest++) = pad_buffer;
    }
  }
  // Rows below the valid region: fill with pointers to the padding buffer
  for (; i < array_rows; i++)
  {
    for (unsigned int j = 0; j < array_cols; j++)
    {
      *(dest++) = pad_buffer;
    }
  }
}
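/* Example (illustrative only): with array_rows = array_cols = 3, pad_top =
 * pad_left = 1 and valid_rows = valid_cols = 2, the nine pointers written by
 * fill_pointer_array are, in row-major order (strides already scaled by
 * element_size):
 *
 *   pad_buffer   pad_buffer          pad_buffer
 *   pad_buffer   base_ptr            base_ptr + ld_col
 *   pad_buffer   base_ptr + ld_row   base_ptr + ld_row + ld_col
 */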
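/* Pointer array constructor (generic kernels)
 *
 * Fill an array of pointers with one entry per kernel point per output point.
 * The array is laid out kernel-point major: for each kernel point (taken in
 * row-major order) there is a contiguous block of output_rows * output_cols
 * pointers, one per output point (also row-major). Each pointer addresses the
 * input element read by that (output point, kernel point) pair, or
 * `pad_buffer_raw` where that element falls in the padding.
 */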
void fill_pointer_array_generic_kernel(
  const size_t element_size,
  void **dest_raw,
  const unsigned int output_rows, const unsigned int output_cols,
  const unsigned int kernel_rows, const unsigned int kernel_cols,
  const unsigned int stride_rows, const unsigned int stride_cols,
  void *base_ptr_raw, size_t ld_row, size_t ld_col,
  void *pad_buffer_raw,
  const unsigned int pad_top, const unsigned int valid_rows,
  const unsigned int pad_left, const unsigned int valid_cols
)
{
  auto dest = reinterpret_cast<char **>(dest_raw);
  auto base_ptr = reinterpret_cast<char *>(base_ptr_raw);
  auto pad_buffer = reinterpret_cast<char *>(pad_buffer_raw);
  ld_row *= element_size;
  ld_col *= element_size;

  const auto last_valid_row = pad_top + valid_rows;
  const auto last_valid_col = pad_left + valid_cols;
  const auto point_stride = output_rows * output_cols;

  // Iterate over the output points; after every point, increment the pointer
  // into the address array.
  for (unsigned int oi = 0; oi < output_rows; oi++)
  {
    for (unsigned int oj = 0; oj < output_cols; oj++)
    {
      auto point_dest = dest;
      dest++;

      // Iterate over kernel points and fill in the pointer array.
      unsigned int ki = 0, ii = oi*stride_rows;
      for (; ii < pad_top && ki < kernel_rows; ii++, ki++)
      {
        // Fill with padding
        for (unsigned int j = 0; j < kernel_cols; j++)
        {
          *point_dest = pad_buffer;
          point_dest += point_stride;
        }
      }
      for (; ii < last_valid_row && ki < kernel_rows; ii++, ki++)
      {
        unsigned int kj = 0, ij = oj*stride_cols;
        for (; ij < pad_left && kj < kernel_cols; ij++, kj++)
        {
          // Padding
          *point_dest = pad_buffer;
          point_dest += point_stride;
        }
        for (; ij < last_valid_col && kj < kernel_cols; ij++, kj++)
        {
          *point_dest = base_ptr + (ii - pad_top)*ld_row + (ij - pad_left)*ld_col;
          point_dest += point_stride;
        }
        for (; kj < kernel_cols; kj++)
        {
          // Padding
          *point_dest = pad_buffer;
          point_dest += point_stride;
        }
      }
      for (; ki < kernel_rows; ki++)
      {
        // Fill with padding
        for (unsigned int j = 0; j < kernel_cols; j++)
        {
          *point_dest = pad_buffer;
          point_dest += point_stride;
        }
      }
    }
  }
}
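/* Indexing sketch (illustrative): given the layout produced above, the
 * pointer supplied for output point (oi, oj) and kernel point (ki, kj) sits at
 *
 *   dest_raw[(ki*kernel_cols + kj) * output_rows*output_cols
 *            + oi*output_cols + oj]
 */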

/* Patch array constructor
 *
 * Some depthwise kernels require an NCHW-ordered patch of input. Here we
 * construct such a patch, and fill in an array of pointers to the rows of the
 * patch.
 */
void fill_nchw_patch_array(
  size_t element_size,
  const void **dest_row_pointers_raw,  // Array of pointers to each row of the patch
  void *dest_patch_raw,  // Pointer to space which can be used to construct the patch
  const unsigned int patch_rows, unsigned int patch_cols,  // Patch size
  const void *src_ptr_raw, size_t ld_row, size_t ld_col,  // Source tensor
  const void *pad_row,  // Pointer to a row of padding values
  const unsigned int pad_top, const unsigned int valid_rows,
  const unsigned int pad_left, const unsigned int valid_cols
)
{
  // Convert into more useful types
  auto row_pointers = reinterpret_cast<const char **>(dest_row_pointers_raw);
  auto dest_patch = reinterpret_cast<char *>(dest_patch_raw);
  auto src = reinterpret_cast<const char *>(src_ptr_raw);
  ld_row *= element_size;
  ld_col *= element_size;

  // Round up the patch columns to be a full quad
  patch_cols = arm_gemm::roundup<unsigned int>(patch_cols, 16 / element_size);

  const auto last_valid_row = std::min(pad_top + valid_rows, patch_rows);
  const auto last_valid_col = std::min(pad_left + valid_cols, patch_cols);

  // Construct the patch and row pointer array together
  unsigned int i = 0;
  for (; i < pad_top; i++)
  {
    // Insert pointers into the padding row
    *(row_pointers++) = reinterpret_cast<const char *>(pad_row);
  }
  for (; i < last_valid_row; i++)
  {
    // Get a copy of the pointer for this row
    auto colptr = src;
    src += ld_row;

    // If the input is already in NCHW format (ld_col == element_size) AND
    // there is no padding, then we just use a pointer to the source tensor;
    // otherwise we need to construct a patch and provide a pointer to it.
    if (ld_col == element_size && pad_left == 0 && last_valid_col == patch_cols)
    {
      *(row_pointers++) = colptr;
    }
    else
    {
      auto patch_col = dest_patch;
      *(row_pointers++) = dest_patch;
      dest_patch += element_size * patch_cols;  // Move the patch pointer on

      // Construct the patch; fill the entirety with padding and then copy in
      // the valid elements.
      memcpy(patch_col, pad_row, element_size * patch_cols);
      patch_col += pad_left * element_size;  // Move over the left padding

      if (ld_col == element_size)
      {
        // If the input is NCHW then copy across as many columns as we can.
        memcpy(patch_col, colptr, (last_valid_col - pad_left) * element_size);
      }
      else
      {
        // If the input is NHWC then copy columns across in turn.
        for (auto j = pad_left; j < last_valid_col; j++)
        {
          memcpy(patch_col, colptr, element_size);  // Copy the valid element
          patch_col += element_size;  // Progress the patch destination
          colptr += ld_col;  // Progress the source pointer
        }
      }
    }
  }
  for (; i < patch_rows; i++)
  {
    // Insert pointers into the padding row
    *(row_pointers++) = reinterpret_cast<const char *>(pad_row);
  }
}
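/* Usage sketch (illustrative only; the buffer names and sizes below are
 * assumptions derived from the code above, not a documented contract). For
 * fp32 data a caller might prepare the buffers roughly as follows:
 *
 *   const size_t element_size = sizeof(float);
 *   const unsigned int patch_rows = 4, patch_cols = 4;
 *   const auto quad_cols = arm_gemm::roundup<unsigned int>(patch_cols, 16 / element_size);
 *
 *   std::vector<float> pad_row(quad_cols, 0.0f);       // one row of padding values
 *   std::vector<float> patch(patch_rows * quad_cols);  // worst-case scratch space
 *   const void *row_ptrs[4];                           // one pointer per patch row
 *
 *   fill_nchw_patch_array(element_size, row_ptrs, patch.data(),
 *                         patch_rows, patch_cols,
 *                         src, ld_row, ld_col,          // caller's source tensor
 *                         pad_row.data(),
 *                         pad_top, valid_rows, pad_left, valid_cols);
 */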


/* Patch array constructor (generic kernels)
 *
 * Construct an array of pointers; one pointer for each output row for each
 * kernel point. Pointers should point at a whole number of QUADS containing an
 * input point for each output point. If the kernel column stride is 1 and the
 * data is NCHW then the input tensor might be addressed directly, otherwise a
 * new patch sample might need to be constructed.
 */
void fill_patch_array_generic_kernel(
  size_t element_size,
  const void **dest_pointers_raw,  // Pointers: one per output row per kernel point
  void *patch_raw,  // Pointer to space which can be used to construct the patch
  const unsigned int output_rows, const unsigned int output_cols,
  const unsigned int kernel_rows, const unsigned int kernel_cols,
  const unsigned int stride_rows, const unsigned int stride_cols,
  const void *src_ptr_raw, size_t ld_row, size_t ld_col,  // Source tensor
  const void *pad_row,  // Pointer to a row of padding values
  const unsigned int pad_top, const unsigned int valid_rows,
  const unsigned int pad_left, const unsigned int valid_cols
)
{
  auto dest = reinterpret_cast<const char **>(dest_pointers_raw);
  auto patch = reinterpret_cast<char *>(patch_raw);
  auto src_ptr = reinterpret_cast<const char *>(src_ptr_raw);
  ld_row *= element_size;
  ld_col *= element_size;

  // Round up the patch columns to a multiple of quad-length
  const auto patch_cols = arm_gemm::roundup<unsigned int>(output_cols, 16 / element_size);

  const auto input_rows = kernel_rows + (output_rows - 1) * stride_rows;
  const auto last_valid_row = std::min(pad_top + valid_rows, input_rows);

  const auto input_cols = kernel_cols + (output_cols - 1) * stride_cols;
  const auto last_valid_col = std::min(pad_left + valid_cols, input_cols);

  for (auto ki = 0u; ki < kernel_rows; ki++)
  {
    for (auto kj = 0u; kj < kernel_cols; kj++)
    {
      auto oi = 0u, ii = ki;
      for (; oi < output_rows && ii < pad_top; oi++, ii += stride_rows)
      {
        // Insert a pointer to the padding row
        *(dest++) = reinterpret_cast<const char *>(pad_row);
      }
      for (; oi < output_rows && ii < last_valid_row; oi++, ii += stride_rows)
      {
        auto rowptr = src_ptr + (ii - pad_top) * ld_row;

        // Construct a sample of the input here
        auto patch_pos = patch;
        *(dest++) = patch;
        patch += patch_cols * element_size;

        // Fill with padding
        memcpy(patch_pos, pad_row, patch_cols * element_size);

        // Fill in the valid elements
        auto oj = 0u, ij = kj;
        for (; oj < patch_cols && ij < pad_left; oj++, ij += stride_cols)
        {
          // Do nothing for padding
          patch_pos += element_size;
        }
        for (; oj < patch_cols && ij < last_valid_col; oj++, ij += stride_cols)
        {
          // Copy from the source tensor
          memcpy(patch_pos, rowptr + (ij - pad_left)*ld_col, element_size);
          patch_pos += element_size;
        }
        // No action required for right-hand padding
      }
      for (; oi < output_rows; oi++)
      {
        *(dest++) = reinterpret_cast<const char *>(pad_row);
      }
    }
  }
}
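/* Sizing note (implied by the loops above, not a documented contract): the
 * pointer array receives kernel_rows * kernel_cols * output_rows entries (one
 * per output row per kernel point); in the worst case the scratch space at
 * `patch_raw` must hold that many rows of roundup(output_cols, 16 / element_size)
 * elements, and `pad_row` must be at least one such row long.
 */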

}  // namespace addressing
}  // namespace arm_conv