xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/indirection.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <stddef.h>
10 
11 #include <fxdiv.h>
12 
13 #include <qnnpack/indirection.h>
14 #include <qnnpack/math.h>
15 #include <qnnpack/operator.h>
16 
/*
 * Fills the indirection buffer of a (grouped) 3D convolution whose
 * GEMM-style micro-kernel consumes output_tile_size output pixels at a time.
 *
 * Buffer layout:
 *  - Each (group, image) pair owns a contiguous region of
 *    tiled_output_size * kernel_size entries, where region index is
 *    group * batch_size + image.
 *  - Within a region, the tile starting at output_tile_start begins at
 *    entry output_tile_start * kernel_size.
 *  - Inside a tile, kernel tap
 *      k = (kernel_z * kernel_height + kernel_y) * kernel_width + kernel_x
 *    owns output_tile_size consecutive entries, one per pixel of the tile.
 *
 * Each entry points into op->input at the tapped input pixel, offset by
 * group * group_input_channels bytes. It is op->zero_pointer when the tap
 * lands in padding. Tile slots past the last real output pixel
 * (tiled_output_index >= output_size) repeat the last real output pixel.
 *
 * op                - operator whose shapes, strides, dilations, paddings,
 *                     input, zero_pointer and indirection_buffer are set up.
 * output_tile_size  - number of output pixels per micro-kernel tile.
 * tiled_output_size - number of output slots filled per (group, image).
 *                     NOTE(review): presumably output_size rounded up to a
 *                     multiple of output_tile_size by the caller; confirm.
 */
void pytorch_qnnp_indirection_init_conv3d(
    pytorch_qnnp_operator_t op,
    size_t output_tile_size,
    size_t tiled_output_size) {
  const void** indirection_buffer = op->indirection_buffer;
  const void* input = op->input;
  const size_t input_pixel_stride = op->input_pixel_stride;
  const void* zero = op->zero_pointer;
  const size_t groups = op->groups;
  const size_t group_input_channels = op->group_input_channels;
  const size_t batch_size = op->batch_size;
  const size_t input_depth = op->input_depth;
  const size_t input_height = op->input_height;
  const size_t input_width = op->input_width;
  const size_t output_depth = op->output_depth;
  const size_t output_height = op->output_height;
  const size_t output_width = op->output_width;
  const size_t kernel_depth = op->kernel_depth;
  const size_t kernel_height = op->kernel_height;
  const size_t kernel_width = op->kernel_width;
  const size_t stride_depth = op->stride_depth;
  const size_t stride_height = op->stride_height;
  const size_t stride_width = op->stride_width;
  const size_t dilation_depth = op->dilation_depth;
  const size_t dilation_height = op->dilation_height;
  const size_t dilation_width = op->dilation_width;
  const size_t input_padding_depth = op->input_padding_depth;
  const size_t input_padding_height = op->input_padding_height;
  const size_t input_padding_width = op->input_padding_width;

  const size_t output_size = output_depth * output_height * output_width;
  if (output_size == 0) {
    /*
     * There is no output pixel to point at. Bail out before
     * fxdiv_init_size_t() is handed a zero divisor and before
     * output_size - 1 wraps around to SIZE_MAX.
     */
    return;
  }
  const size_t kernel_size = kernel_depth * kernel_height * kernel_width;
  const struct fxdiv_divisor_size_t output_yx_divisor =
      fxdiv_init_size_t(output_height * output_width);
  const struct fxdiv_divisor_size_t output_x_divisor =
      fxdiv_init_size_t(output_width);

  for (size_t group = 0; group < groups; group++) {
    for (size_t image = 0; image < batch_size; image++) {
      /* First entry of this (group, image) region. */
      const size_t region_base =
          (group * batch_size + image) * tiled_output_size * kernel_size;
      for (size_t output_tile_start = 0; output_tile_start < tiled_output_size;
           output_tile_start += output_tile_size) {
        for (size_t output_tile_offset = 0;
             output_tile_offset < output_tile_size;
             output_tile_offset++) {
          const size_t tiled_output_index =
              output_tile_start + output_tile_offset;
          /* Slots in the tail of the last tile reuse the last real pixel. */
          const size_t output_index = min(tiled_output_index, output_size - 1);
          const struct fxdiv_result_size_t z_yx =
              fxdiv_divide_size_t(output_index, output_yx_divisor);
          const struct fxdiv_result_size_t y_x =
              fxdiv_divide_size_t(z_yx.remainder, output_x_divisor);
          const size_t output_z = z_yx.quotient;
          const size_t output_y = y_x.quotient;
          const size_t output_x = y_x.remainder;
          /*
           * Entry of kernel tap 0 for this pixel; tap k lives
           * k * output_tile_size entries further on.
           */
          const size_t pixel_base =
              region_base + output_tile_start * kernel_size + output_tile_offset;

          for (size_t kernel_z = 0; kernel_z < kernel_depth; kernel_z++) {
            /*
             * size_t arithmetic: a coordinate inside the leading padding
             * wraps around to a huge value, so a single "< extent" test
             * rejects padding on both sides of the input.
             */
            const size_t input_z = output_z * stride_depth +
                kernel_z * dilation_depth - input_padding_depth;
            for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) {
              const size_t input_y = output_y * stride_height +
                  kernel_y * dilation_height - input_padding_height;
              for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
                const size_t input_x = output_x * stride_width +
                    kernel_x * dilation_width - input_padding_width;
                const size_t kernel_index =
                    (kernel_z * kernel_height + kernel_y) * kernel_width +
                    kernel_x;
                const size_t index =
                    pixel_base + kernel_index * output_tile_size;
                if (input_z < input_depth && input_y < input_height &&
                    input_x < input_width) {
                  indirection_buffer[index] = (const char*)input +
                      (((image * input_depth + input_z) * input_height +
                        input_y) *
                           input_width +
                       input_x) *
                          input_pixel_stride +
                      group * group_input_channels;
                } else {
                  indirection_buffer[index] = zero;
                }
              }
            }
          }
        }
      }
    }
  }
}
136 
137 /**
138  * Imagine we want to do a depthwise conv or average pooling with these parameters:
139  * kernel_width/height=3 stride=2
140  * Input is:
141  *  ---------------
142  *  |0|1|2|3|4|5|6|
143  *  ---------------       -------
144  *  | | | | | | | |   to  |0|1|2|
145  *  ---------------       -------
146  *  | | | | | | | |       | | | |
147  *  ---------------       -------
148  *  | | | | | | | |
149  *  ---------------
150  *  | | | | | | | |
151  *  ---------------
152  *
153  *  Thus we are going from width=7 height=5 input to width=3 height=2
154  *  Convince yourself that input 5x7 with pooling params of 3x3 kernel
155  *  with 2x2 stride gets you to 2x3 output.
156  *  Now for each output place (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
157  *  we have 3x3 input.
158  *  For just the first row of output this looks as follows:
159  *  pixel:0   pixel:1  pixel:2
160  *  -------   -------  -------
161  *  |0|1|2|   |2|3|4|  |4|5|6|
162  *  -------   -------  -------
163  *  | | | |   | | | |  | | | |
164  *  -------   -------  -------
165  *  | | | |   | | | |  | | | |
166  *  -------   -------  -------
167  *  As you can see there is some overlap in the input needed for each
168  *  output pixel.
169  *  What is indirection buffer:
170  *  Indirection buffer just stores the pointer to the underlying data.
171  *  In this case pointer for a particular input position will point to
172  *  all the input channels of that position in NHWC format.
173  *  So one option for the aforementioned storage would be:
174  *  For each output position: store a 3x3 array of pointers. Thus we
175  *  would have 3x3 * 3 (3 output pixels of the first row) = 27 pointers
176  *  stored.
177  *  Now instead we store the pointer in this format:
178  *  ---------------
179  *  |0|1|2|3|4|5|6|
180  *  ---------------
181  *  | | | | | | | |
182  *  ---------------
183  *  | | | | | | | |
184  *  ---------------
185  *  Then we have all the pointers needed as before, but with less duplication.
186  *  So instead of 27 pointers now we have:
187  *  (3 (# of output pixels) - 1) * 2 (stride) * 3 (kernel height) + 3 * 3 (kernel h*w)
188  *  = 4 * 3 + 9
189  *  = 21 pointers.
190  *  which is the step_height formula in pytorch_qnnp_indirection_set_step_dimensions() (with kernel depth 1).
191  *  Now in order for this to work the kernel has to be adjusted.
192  *  Here the kernel produces output for an entire row. Thus as you move from one
193  *  output pixel to the next, the jump in the indirection buffer is not 3*3 = 9
194  *  but kernel height (3) * stride (2) = 6.
195  *  You can see this in operator-run.c.
196  *
197  * step_width: The number of yz slices of the kernel to traverse to move from
198  *   the starting input index of an output pixel in the indirection buffer to
199  *   that of the output pixel directly after it in the same row.
200  *   i.e. if indirection_buffer[j] points to the first input pixel used to
201  *   compute the i'th output pixel, then
202  *   indirection_buffer[j + (kernel_depth * kernel_height * step_width)]
203  *   points to the first input pixel used to compute the (i + 1)'th output
204  *   pixel if in the same row
205  *   When dilation is 1 (for convolution): if neighboring output pixels use
206  *   overlapping regions of the input, this overlap is not included in the
207  *   indirection buffer (saving some space), hence step width is set to stride
208  *   width
209  *
210  * step_height: The number of pointers to traverse to move from an output
211  *   pixel's first input's index in the indirection buffer to that of the
212  *   output pixel one ROW (one output y) after it.
213  *   i.e. if indirection_buffer[j] points to the first input pixel used to
214  *   compute the i'th output pixel, then
215  *   indirection_buffer[j + step_height] points to the first
216  *   input pixel used to compute the output pixel one row below-
217  *   the (i + output_width)'th output pixel
218  *
219  * step_depth: Same as step height but for an xy slice rather than a row
220  *
221  * The input operator's step dimensions must have been set up before calling
222  * this function.
223  */
void pytorch_qnnp_indirection_init_dwconv(
    pytorch_qnnp_operator_t op,
    size_t batch_start) {
  /*
   * Fills the entries for images [batch_start, op->batch_size). The entry for
   * output pixel (output_z, output_y, output_x) and kernel tap
   * (kernel_z, kernel_y, kernel_x) sits at
   *
   *     (image * output_depth + output_z) * step_depth         -- slice
   *   + output_y * step_height                                 -- row
   *   + output_x * step_width * kernel_height * kernel_depth   -- pixel
   *   + (kernel_x * kernel_height + kernel_y) * kernel_depth   -- tap
   *   + kernel_z
   *
   * so within a pixel, kernel_z varies fastest and kernel_x slowest. The
   * partial sums are hoisted into the loop nest below instead of being
   * recomputed by a function-like macro. Each entry points at the start of
   * an input pixel (all channels, no group offset), or at op->zero_pointer
   * when the tap lands in padding.
   *
   * op->step_depth/step_height/step_width must already be set.
   */
  const void** indirection_buffer = op->indirection_buffer;
  const void* input = op->input;
  const size_t input_pixel_stride = op->input_pixel_stride;
  const void* zero = op->zero_pointer;
  const size_t batch_size = op->batch_size;
  const size_t input_depth = op->input_depth;
  const size_t input_height = op->input_height;
  const size_t input_width = op->input_width;
  const size_t output_depth = op->output_depth;
  const size_t output_height = op->output_height;
  const size_t output_width = op->output_width;
  const size_t kernel_depth = op->kernel_depth;
  const size_t kernel_height = op->kernel_height;
  const size_t kernel_width = op->kernel_width;
  const size_t stride_depth = op->stride_depth;
  const size_t stride_height = op->stride_height;
  const size_t stride_width = op->stride_width;
  const size_t dilation_depth = op->dilation_depth;
  const size_t dilation_height = op->dilation_height;
  const size_t dilation_width = op->dilation_width;
  const size_t input_padding_depth = op->input_padding_depth;
  const size_t input_padding_height = op->input_padding_height;
  const size_t input_padding_width = op->input_padding_width;
  const size_t step_depth = op->step_depth;
  const size_t step_height = op->step_height;
  const size_t step_width = op->step_width;

  /* Entries per kernel column: one per (kernel_y, kernel_z) tap. */
  const size_t column_entries = kernel_height * kernel_depth;
  /* Distance between consecutive output pixels of the same row. */
  const size_t pixel_entries = step_width * column_entries;

  for (size_t image = batch_start; image < batch_size; image++) {
    for (size_t output_z = 0; output_z < output_depth; output_z++) {
      const size_t slice_base = (image * output_depth + output_z) * step_depth;
      for (size_t kernel_z = 0; kernel_z < kernel_depth; kernel_z++) {
        /*
         * size_t arithmetic: a coordinate inside the leading padding wraps
         * around to a huge value, so one "< extent" test rejects padding on
         * both sides of the input.
         */
        const size_t input_z = output_z * stride_depth +
            kernel_z * dilation_depth - input_padding_depth;
        for (size_t output_y = 0; output_y < output_height; output_y++) {
          const size_t row_base = slice_base + output_y * step_height;
          for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) {
            const size_t input_y = output_y * stride_height +
                kernel_y * dilation_height - input_padding_height;
            /* Offset of this (kernel_y, kernel_z) tap inside a column. */
            const size_t tap_offset = kernel_y * kernel_depth + kernel_z;
            for (size_t output_x = 0; output_x < output_width; output_x++) {
              const size_t pixel_base =
                  row_base + output_x * pixel_entries + tap_offset;
              for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
                const size_t input_x = output_x * stride_width +
                    kernel_x * dilation_width - input_padding_width;
                const size_t index = pixel_base + kernel_x * column_entries;
                if (input_z < input_depth && input_y < input_height &&
                    input_x < input_width) {
                  indirection_buffer[index] = (const char*)input +
                      (((image * input_depth + input_z) * input_height +
                        input_y) *
                           input_width +
                       input_x) *
                          input_pixel_stride;
                } else {
                  indirection_buffer[index] = zero;
                }
              }
            }
          }
        }
      }
    }
  }
}
338 
/*
 * Fills the indirection buffer of a (grouped) 2D deconvolution (transposed
 * convolution) whose GEMM-style micro-kernel consumes output_tile_size
 * output pixels at a time.
 *
 * Buffer layout:
 *  - Each (group, image) pair owns a region of
 *    tiled_output_size * kernel_size entries.
 *  - Inside the tile starting at output_tile_start, kernel tap
 *    k = kernel_y * kernel_width + kernel_x owns output_tile_size
 *    consecutive entries.
 *  - Tile slots past the last real output pixel repeat that pixel.
 *
 * Output pixel (output_y, output_x) with tap (kernel_y, kernel_x) gathers
 * from input pixel (y / stride_height, x / stride_width), where
 *   y = output_y + input_padding_height - kernel_y * dilation_height
 *   x = output_x + input_padding_width  - kernel_x * dilation_width
 * This happens only when y and x are exact multiples of their strides and
 * the quotients fall inside the input; otherwise the entry is
 * op->zero_pointer. Pointers are offset by group * group_input_channels
 * bytes.
 *
 * op                - operator with shapes, strides, dilations, paddings,
 *                     input, zero_pointer and indirection_buffer set up.
 * output_tile_size  - number of output pixels per micro-kernel tile.
 * tiled_output_size - number of output slots filled per (group, image).
 */
void pytorch_qnnp_indirection_init_deconv2d(
    pytorch_qnnp_operator_t op,
    size_t output_tile_size,
    size_t tiled_output_size) {
  const void** indirection_buffer = op->indirection_buffer;
  const void* input = op->input;
  const size_t input_pixel_stride = op->input_pixel_stride;
  const void* zero = op->zero_pointer;
  const size_t groups = op->groups;
  const size_t group_input_channels = op->group_input_channels;
  const size_t batch_size = op->batch_size;
  const size_t input_height = op->input_height;
  const size_t input_width = op->input_width;
  const size_t output_height = op->output_height;
  const size_t output_width = op->output_width;
  const size_t kernel_height = op->kernel_height;
  const size_t kernel_width = op->kernel_width;
  const size_t stride_height = op->stride_height;
  const size_t stride_width = op->stride_width;
  const size_t dilation_height = op->dilation_height;
  const size_t dilation_width = op->dilation_width;
  const size_t input_padding_height = op->input_padding_height;
  const size_t input_padding_width = op->input_padding_width;

  const size_t output_size = output_height * output_width;
  if (output_size == 0) {
    /*
     * There is no output pixel to point at. Without this guard,
     * output_size - 1 wraps around and output_index / output_width can
     * divide by zero.
     */
    return;
  }
  const size_t kernel_size = kernel_height * kernel_width;

  for (size_t group = 0; group < groups; group++) {
    for (size_t image = 0; image < batch_size; image++) {
      /* First entry of this (group, image) region. */
      const size_t region_base =
          (group * batch_size + image) * tiled_output_size * kernel_size;
      for (size_t output_tile_start = 0; output_tile_start < tiled_output_size;
           output_tile_start += output_tile_size) {
        for (size_t output_tile_offset = 0;
             output_tile_offset < output_tile_size;
             output_tile_offset++) {
          const size_t tiled_output_index =
              output_tile_start + output_tile_offset;
          /* Slots in the tail of the last tile reuse the last real pixel. */
          const size_t output_index = min(tiled_output_index, output_size - 1);
          const size_t output_y = output_index / output_width;
          const size_t output_x = output_index % output_width;
          /* Entry of kernel tap 0 for this pixel. */
          const size_t pixel_base =
              region_base + output_tile_start * kernel_size + output_tile_offset;
          for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) {
            /*
             * y may wrap around (size_t); the wrapped value can never pass
             * the "input_y < input_height" test below.
             */
            const size_t y =
                output_y + input_padding_height - kernel_y * dilation_height;
            const size_t input_y = y / stride_height;
            for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
              const size_t x =
                  output_x + input_padding_width - kernel_x * dilation_width;
              const size_t input_x = x / stride_width;
              const size_t index = pixel_base +
                  (kernel_y * kernel_width + kernel_x) * output_tile_size;
              if (input_y * stride_height == y && input_y < input_height &&
                  input_x * stride_width == x && input_x < input_width) {
                indirection_buffer[index] = (const char*)input +
                    ((image * input_height + input_y) * input_width + input_x) *
                        input_pixel_stride +
                    group * group_input_channels;
              } else {
                indirection_buffer[index] = zero;
              }
            }
          }
        }
      }
    }
  }
}
407 
/*
 * Fills the max-pooling indirection buffer for images
 * [batch_start, op->batch_size).
 *
 * The entry for output pixel (out_y, out_x) and window tap (win_y, win_x)
 * is at
 *   (image * output_height + out_y) * step_height
 *   + out_x * step_width * kernel_height
 *   + win_x * kernel_height + win_y
 * so the taps of one pixel are stored with win_y varying fastest.
 *
 * No zero pointer is used. Tap coordinates are passed through doz() with the
 * padding and then capped at the last row/column with min(), so every entry
 * points at a real input pixel. NOTE(review): assumes doz(a, b) is the
 * saturating "a - b, or 0" from qnnpack/math.h.
 *
 * NOTE(review): with dilation > 1, the clamped pixel is not necessarily one
 * of the window's own taps. For example, with padding 1 and dilation 2 the
 * taps are at -1 and 1, but -1 is clamped to 0. An out-of-window pixel can
 * then take part in the maximum. Confirm whether callers rule out that
 * combination.
 *
 * op->step_height and op->step_width must already be set.
 */
void pytorch_qnnp_indirection_init_maxpool2d(
    pytorch_qnnp_operator_t op,
    size_t batch_start) {
  const void** const entries = op->indirection_buffer;
  const char* const pixels = (const char*)op->input;
  const size_t pixel_stride = op->input_pixel_stride;
  const size_t batch_end = op->batch_size;
  const size_t in_height = op->input_height;
  const size_t in_width = op->input_width;
  const size_t out_height = op->output_height;
  const size_t out_width = op->output_width;
  const size_t window_height = op->kernel_height;
  const size_t window_width = op->kernel_width;
  const size_t stride_y = op->stride_height;
  const size_t stride_x = op->stride_width;
  const size_t dilation_y = op->dilation_height;
  const size_t dilation_x = op->dilation_width;
  const size_t pad_top = op->input_padding_height;
  const size_t pad_left = op->input_padding_width;
  const size_t row_step = op->step_height;
  /* Distance between consecutive output pixels of a row, in entries. */
  const size_t pixel_step = op->step_width * window_height;

  for (size_t n = batch_start; n < batch_end; n++) {
    for (size_t out_y = 0; out_y < out_height; out_y++) {
      const size_t row_start = (n * out_height + out_y) * row_step;
      for (size_t win_y = 0; win_y < window_height; win_y++) {
        const size_t in_y = min(
            doz(out_y * stride_y + win_y * dilation_y, pad_top),
            in_height - 1);
        /* Index of the first pixel of input row in_y of image n. */
        const size_t in_row = (n * in_height + in_y) * in_width;
        for (size_t out_x = 0; out_x < out_width; out_x++) {
          const size_t tap_start = row_start + out_x * pixel_step + win_y;
          for (size_t win_x = 0; win_x < window_width; win_x++) {
            const size_t in_x = min(
                doz(out_x * stride_x + win_x * dilation_x, pad_left),
                in_width - 1);
            entries[tap_start + win_x * window_height] =
                pixels + (in_row + in_x) * pixel_stride;
          }
        }
      }
    }
  }
}
457 
/*
 * Computes op->step_width, op->step_height and op->step_depth. These are the
 * strides, in indirection-buffer entries, between neighbouring output
 * pixels, output rows and output slices in the depthwise-convolution and
 * pooling indirection layouts.
 *
 * step_width is counted in kernel columns; each column holds
 * kernel_height * kernel_depth entries.
 *  - dwconv: stride_width when dilation_width == 1, so columns that
 *    neighbouring pixels have in common are stored once. Otherwise it is
 *    kernel_width, so each pixel keeps its own columns.
 *  - average pooling: min(stride_width, kernel_width).
 *  - max pooling: kernel_width when dilation_width > 1, else
 *    min(stride_width, kernel_width).
 *
 * step_height is one full kernel for the first pixel of a row, plus
 * step_width columns for each of the remaining output_width - 1 pixels.
 * step_depth is output_height rows.
 *
 * Any other ukernel_type hits PYTORCH_QNNP_UNREACHABLE.
 */
void pytorch_qnnp_indirection_set_step_dimensions(pytorch_qnnp_operator_t op) {
  /*
   * A kernel_depth of 0 is treated as 1. NOTE(review): presumably left at 0
   * by operators without a depth dimension; confirm.
   */
  const size_t kernel_depth = (op->kernel_depth == 0) ? 1 : op->kernel_depth;
  const size_t kernel_height = op->kernel_height;
  const size_t kernel_width = op->kernel_width;
  const size_t stride_width = op->stride_width;
  /* Entries in one kernel column: one per (y, z) tap. */
  const size_t column_entries = kernel_height * kernel_depth;

  size_t step_width = 0;
  switch (op->ukernel_type) {
    case pytorch_qnnp_ukernel_type_dwconv:
      if (op->dilation_width == 1) {
        step_width = stride_width;
      } else {
        step_width = kernel_width;
      }
      break;
    case pytorch_qnnp_ukernel_type_average_pooling:
      step_width = min(stride_width, kernel_width);
      break;
    case pytorch_qnnp_ukernel_type_max_pooling:
      if (op->dilation_width > 1) {
        step_width = kernel_width;
      } else {
        step_width = min(stride_width, kernel_width);
      }
      break;
    default:
      PYTORCH_QNNP_UNREACHABLE;
  }

  const size_t full_kernel_entries = column_entries * kernel_width;
  const size_t row_entries = full_kernel_entries +
      (op->output_width - 1) * step_width * column_entries;

  op->step_depth = row_entries * op->output_height;
  op->step_height = row_entries;
  op->step_width = step_width;
}
493