/*
 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <string.h>

#include "aom_dsp/aom_dsp_common.h"
#include "av1/common/av1_common_int.h"
#include "av1/encoder/cnn.h"

// Clamp index a to the valid range [0, hi - 1].
#define CLAMPINDEX(a, hi) ((a) < 0 ? 0 : ((a) >= (hi) ? ((hi)-1) : (a)))

typedef struct {
  const float **input;
  int in_width;
  int in_height;
  int in_stride;
  const CNN_LAYER_CONFIG *layer_config;
  float **output;
  int out_stride;
  int start_idx;
  int th_step;
} CONVOLVE_OPS;

static inline float softsign(float x) { return x / (fabsf(x) + 1.0f); }

static inline float relu(float x) { return (x < 0) ? 0 : x; }

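// A simple planar tensor: `channels` float planes of size width x height,
// laid out back to back in a single allocation anchored at buf[0].
// allocsize is 0 when the plane pointers are externally owned (see
// assign_tensor()), in which case free_tensor() leaves them alone.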
typedef struct {
  int allocsize;
  int channels;
  int width, height, stride;
  float *buf[CNN_MAX_CHANNELS];
} TENSOR;

static void init_tensor(TENSOR *tensor) { memset(tensor, 0, sizeof(*tensor)); }

static void free_tensor(TENSOR *tensor) {
  if (tensor->allocsize) {
    aom_free(tensor->buf[0]);
    tensor->buf[0] = NULL;
    tensor->allocsize = 0;
  }
}

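// Ensures `tensor` can hold channels * width * height floats, reallocating
// only when the current buffer is too small, then rebuilds the per-channel
// plane pointers. Returns false on allocation failure.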
static bool realloc_tensor(TENSOR *tensor, int channels, int width,
                           int height) {
  const int newallocsize = channels * width * height;
  if (tensor->allocsize < newallocsize) {
    free_tensor(tensor);
    tensor->buf[0] =
        (float *)aom_malloc(sizeof(*tensor->buf[0]) * newallocsize);
    if (!tensor->buf[0]) return false;
    tensor->allocsize = newallocsize;
  }
  tensor->width = width;
  tensor->height = height;
  tensor->stride = width;
  tensor->channels = channels;
  for (int c = 1; c < channels; ++c)
    tensor->buf[c] = &tensor->buf[0][c * width * height];
  return true;
}

static void copy_tensor(const TENSOR *src, int copy_channels, int dst_offset,
                        TENSOR *dst) {
  assert(src->width == dst->width);
  assert(src->height == dst->height);
  assert(copy_channels <= src->channels);
  if (src->stride == dst->width && dst->stride == dst->width) {
    for (int c = 0; c < copy_channels; ++c) {
      memcpy(dst->buf[dst_offset + c], src->buf[c],
             sizeof(*dst->buf[0]) * src->width * src->height);
    }
  } else {
    for (int c = 0; c < copy_channels; ++c) {
      for (int r = 0; r < dst->height; ++r) {
        memcpy(&dst->buf[dst_offset + c][r * dst->stride],
               &src->buf[c][r * src->stride],
               dst->width * sizeof(*dst->buf[c]));
      }
    }
  }
}

static void assign_tensor(TENSOR *tensor, float *buf[CNN_MAX_CHANNELS],
                          int channels, int width, int height, int stride) {
  tensor->allocsize = 0;
  tensor->channels = channels;
  tensor->width = width;
  tensor->height = height;
  tensor->stride = stride;
  if (buf) {
    for (int c = 0; c < channels; ++c) tensor->buf[c] = buf[c];
  } else {
    for (int c = 0; c < channels; ++c) tensor->buf[c] = NULL;
  }
}

static void swap_tensor(TENSOR *t1, TENSOR *t2) {
  TENSOR t = *t1;
  *t1 = *t2;
  *t2 = t;
}

// Concatenates src onto dst along the channel dimension: the result keeps
// the original dst channels first, followed by the channels of src.
static bool concat_tensor(const TENSOR *src, TENSOR *dst) {
  assert(src->width == dst->width);
  assert(src->height == dst->height);

  const int dst_channels = dst->channels;
  const int channels = dst->channels + src->channels;
  const int newallocsize = channels * dst->width * dst->height;
  if (dst->allocsize < newallocsize) {
    TENSOR t;
    init_tensor(&t);
    // Allocate new buffers and copy the dst channels first.
    if (!realloc_tensor(&t, channels, dst->width, dst->height)) return false;
    copy_tensor(dst, dst->channels, 0, &t);
    // Swap the tensors and free the old buffers
    swap_tensor(dst, &t);
    free_tensor(&t);
  }
  for (int c = 1; c < channels; ++c)
    dst->buf[c] = &dst->buf[0][c * dst->width * dst->height];
  // Copy the channels in src after the first dst_channels channels.
  copy_tensor(src, src->channels, dst_channels, dst);
  return true;
}

#ifndef NDEBUG
static int check_tensor_equal_dims(TENSOR *t1, TENSOR *t2) {
  return (t1->width == t2->width && t1->height == t2->height);
}

static int check_tensor_equal_size(TENSOR *t1, TENSOR *t2) {
  return (t1->channels == t2->channels && t1->width == t2->width &&
          t1->height == t2->height);
}
#endif  // NDEBUG

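// Worked example for the non-deconvolve formulas below: with in_width = 10
// and skip_width = 3, SAME padding gives out_width = (10 + 3 - 1) / 3 = 4,
// i.e. ceil(10 / 3); VALID padding with filter_width = 3 gives
// out_width = (10 - 3 + 3) / 3 = 3.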
void av1_find_cnn_layer_output_size(int in_width, int in_height,
                                    const CNN_LAYER_CONFIG *layer_config,
                                    int *out_width, int *out_height) {
  assert(layer_config->skip_width > 0);
  assert(layer_config->skip_height > 0);
  if (!layer_config->deconvolve) {
    switch (layer_config->pad) {
      case PADDING_SAME_ZERO:
      case PADDING_SAME_REPLICATE:
        *out_width = (in_width + layer_config->skip_width - 1) /
                     layer_config->skip_width;
        *out_height = (in_height + layer_config->skip_height - 1) /
                      layer_config->skip_height;
        break;
      case PADDING_VALID:
        *out_width =
            (in_width - layer_config->filter_width + layer_config->skip_width) /
            layer_config->skip_width;
        *out_height = (in_height - layer_config->filter_height +
                       layer_config->skip_height) /
                      layer_config->skip_height;
        break;
      default: assert(0 && "Unknown padding type");
    }
  } else {
    switch (layer_config->pad) {
      case PADDING_SAME_ZERO:
      case PADDING_SAME_REPLICATE:
        *out_width = in_width * layer_config->skip_width;
        *out_height = in_height * layer_config->skip_height;
        break;
      case PADDING_VALID:
        *out_width = (in_width - 1) * layer_config->skip_width +
                     layer_config->filter_width;
        *out_height = (in_height - 1) * layer_config->skip_height +
                      layer_config->filter_height;
        break;
      default: assert(0 && "Unknown padding type");
    }
  }
}

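// Fills channels_per_branch[] with the number of channels each branch
// carries after this layer: branches receiving a copy get the copied
// channel count (which depends on branch_copy_type), while the layer's own
// branch gets out_channels plus the channels of every branch selected by
// branches_to_combine.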
static void find_cnn_out_channels(const CNN_LAYER_CONFIG *layer_config,
                                  int channels_per_branch[]) {
  int branch = layer_config->branch;
  const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
      if (layer_config->branch_copy_type == BRANCH_INPUT) {
        channels_per_branch[b] = layer_config->in_channels;
      } else if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
        channels_per_branch[b] = layer_config->out_channels;
      } else if (layer_config->branch_copy_type == BRANCH_COMBINED) {
        channels_per_branch[b] = layer_config->out_channels;
        for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
          if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
            assert(channels_per_branch[c] > 0);
            channels_per_branch[b] += channels_per_branch[c];
          }
        }
      }
    }
  }
  channels_per_branch[branch] = layer_config->out_channels;
  for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
    if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
      assert(channels_per_branch[c] > 0);
      channels_per_branch[branch] += channels_per_branch[c];
    }
  }
}

#if CONFIG_DEBUG
static inline int cnn_has_at_least_one_output(const CNN_CONFIG *cnn_config) {
  const int num_layers = cnn_config->num_layers;
  const CNN_LAYER_CONFIG *layer_configs = cnn_config->layer_config;

  for (int idx = 0; idx < num_layers; idx++) {
    if (layer_configs[idx].output_num != -1) {
      return 1;
    }
  }
  return 0;
}
#endif

void av1_find_cnn_output_size(int in_width, int in_height,
                              const CNN_CONFIG *cnn_config, int *out_width,
                              int *out_height, int *out_channels) {
  int channels_per_branch[CNN_MAX_BRANCHES] = { 0 };
  int i_width[CNN_MAX_BRANCHES] = { 0 };
  int i_height[CNN_MAX_BRANCHES] = { 0 };
  i_width[0] = in_width + cnn_config->ext_width * 2;
  i_height[0] = in_height + cnn_config->ext_height * 2;

#if CONFIG_DEBUG
  assert(cnn_has_at_least_one_output(cnn_config));
#endif

  for (int i = 0; i < cnn_config->num_layers; ++i) {
    const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[i];
    const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
    const int branch = layer_config->branch;
    int o_width = 0, o_height = 0;

    if (layer_config->branch_copy_type == BRANCH_INPUT) {
      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
        if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
          assert(i_width[branch] > 0 && i_height[branch] > 0);
          i_width[b] = i_width[branch];
          i_height[b] = i_height[branch];
        }
      }
    }

    av1_find_cnn_layer_output_size(i_width[branch], i_height[branch],
                                   layer_config, &o_width, &o_height);
    i_width[branch] = o_width;
    i_height[branch] = o_height;

    if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
        if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
          i_width[b] = o_width;
          i_height[b] = o_height;
        }
      }
    }

    find_cnn_out_channels(layer_config, channels_per_branch);

    const int output_num = layer_config->output_num;
    if (output_num != -1) {  // Current layer is an output layer
      out_width[output_num] = o_width;
      out_height[output_num] = o_height;
      out_channels[output_num] = channels_per_branch[layer_config->branch];
    }
  }
}

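// Offset of the first filter tap for a strided convolution with SAME
// padding, chosen so the sampled output positions stay centered over the
// input; the result never exceeds the half filter width.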
static inline int get_start_shift_convolve(int width, int filt_width,
                                           int stride) {
  const int mod = (width % stride);
  const int filt_off = (filt_width - 1) / 2;
  const int dif = (mod ? mod - 1 : stride - 1);
  return AOMMIN((dif + (filt_width % 2)) / 2, filt_off);
}

void av1_cnn_add_c(float **output, int channels, int width, int height,
                   int stride, const float **add) {
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < height; ++i)
      for (int j = 0; j < width; ++j)
        output[c][i * stride + j] += add[c][i * stride + j];
  }
}

void av1_cnn_activate_c(float **output, int channels, int width, int height,
                        int stride, ACTIVATION layer_activation) {
  if (layer_activation == RELU) {
    for (int c = 0; c < channels; ++c) {
      for (int i = 0; i < height; ++i)
        for (int j = 0; j < width; ++j)
          output[c][i * stride + j] = relu(output[c][i * stride + j]);
    }
  } else if (layer_activation == SOFTSIGN) {
    for (int c = 0; c < channels; ++c) {
      for (int i = 0; i < height; ++i)
        for (int j = 0; j < width; ++j)
          output[c][i * stride + j] = softsign(output[c][i * stride + j]);
    }
  } else if (layer_activation == SIGMOID) {
    assert(0 && "Sigmoid is not supported in CNN.");  // TODO: add sigmoid.
  } else if (layer_activation != NONE) {
    assert(0 && "Unknown activation type");
  }
}

static bool copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
                                           const CNN_LAYER_CONFIG *layer_config,
                                           int branch, TENSOR branch_output[]) {
  const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
      // Copy the layer's active tensor to the output tensor of each branch b
      // selected in the mask. That output then serves as the input to the
      // first layer executed on branch b.
      int copy_channels = branch_config->channels_to_copy > 0
                              ? branch_config->channels_to_copy
                              : layer_active_tensor->channels;
      if (!realloc_tensor(&branch_output[b], copy_channels,
                          layer_active_tensor->width,
                          layer_active_tensor->height)) {
        return false;
      }
      copy_tensor(layer_active_tensor, copy_channels, 0, &branch_output[b]);
    }
  }
  return true;
}

// CNNConvolve for the case where maxpool is 1, skip_width or skip_height is
// greater than 1, and padding is PADDING_SAME_ZERO.
static void convolve_maxpool_padding_zero(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    const int cstep, const int filter_width_half,
    const int filter_height_half) {
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = 0, u = 0; h < in_height; h += layer_config->skip_height, ++u) {
      for (int w = 0, v = 0; w < in_width; w += layer_config->skip_width, ++v) {
        for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
             ++hh) {
          for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
               ++ww) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int ii = hh + l - filter_height_half;
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int jj = ww + m - filter_width_half;
                  if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
                    continue;
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            const float a = sum;
            if (h == hh && w == ww)
              output[i][u * out_stride + v] = a;
            else
              output[i][u * out_stride + v] =
                  AOMMAX(output[i][u * out_stride + v], a);
          }
        }
      }
    }
  }
}

// CNNConvolve for the case where maxpool is 1, skip_width or skip_height is
// greater than 1, and padding is PADDING_SAME_REPLICATE.
static void convolve_maxpool_padding_replicate(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    const int cstep, const int filter_width_half,
    const int filter_height_half) {
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = 0, u = 0; h < in_height; h += layer_config->skip_height, ++u) {
      for (int w = 0, v = 0; w < in_width; w += layer_config->skip_width, ++v) {
        for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
             ++hh) {
          for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
               ++ww) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int ii =
                    CLAMPINDEX(hh + l - filter_height_half, in_height);
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int jj =
                      CLAMPINDEX(ww + m - filter_width_half, in_width);
                  assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            const float a = sum;
            if (h == hh && w == ww)
              output[i][u * out_stride + v] = a;
            else
              output[i][u * out_stride + v] =
                  AOMMAX(output[i][u * out_stride + v], a);
          }
        }
      }
    }
  }
}

// CNNConvolve for the case where maxpool is 1, skip_width or skip_height is
// greater than 1, and padding is PADDING_VALID.
static void convolve_maxpool_padding_valid(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    const int cstep) {
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = 0, u = 0; h < in_height - layer_config->filter_height + 1;
         h += layer_config->skip_height, ++u) {
      for (int w = 0, v = 0; w < in_width - layer_config->filter_width + 1;
           w += layer_config->skip_width, ++v) {
        for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
             ++hh) {
          for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
               ++ww) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int ii = hh + l;
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int jj = ww + m;
                  assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            const float a = sum;
            if (h == hh && w == ww)
              output[i][u * out_stride + v] = a;
            else
              output[i][u * out_stride + v] =
                  AOMMAX(output[i][u * out_stride + v], a);
          }
        }
      }
    }
  }
}

// CNNConvolve for the case where maxpool is 0 and both filter_height and
// filter_width are equal to 1.
static void convolve_element_wise(const float **input, int in_width,
                                  int in_height, int in_stride,
                                  const CNN_LAYER_CONFIG *const layer_config,
                                  float **output, int out_stride, int start_idx,
                                  int step) {
  const int start_h = get_start_shift_convolve(
      in_height, layer_config->filter_height, layer_config->skip_height);
  const int start_w =
      get_start_shift_convolve(in_width, layer_config->filter_width,
                               layer_config->skip_width) +
      start_idx * layer_config->skip_width;
  const int out_w_step = AOMMAX(step, 1);
  const int in_w_step = layer_config->skip_width * out_w_step;
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = start_h, u = 0; h < in_height;
         h += layer_config->skip_height, ++u) {
      const int in_h = h * in_stride;
      const int out_h = u * out_stride + start_idx;
      for (int w = start_w, out_index = out_h; w < in_width;
           w += in_w_step, out_index += out_w_step) {
        float sum = layer_config->bias[i];
        for (int k = 0; k < layer_config->in_channels; ++k) {
          sum += layer_config->weights[k * layer_config->out_channels + i] *
                 input[k][in_h + w];
        }
        output[i][out_index] = sum;
      }
    }
  }
}

// CNNConvolve for the case where maxpool is 0 and padding is
// PADDING_SAME_ZERO.
static void convolve_no_maxpool_padding_zero(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int filter_width_half,
    const int filter_height_half, const int ii_shift, const int jj_shift,
    const int channel_step) {
  const int start_h = get_start_shift_convolve(
      in_height, layer_config->filter_height, layer_config->skip_height);
  const int start_w = get_start_shift_convolve(
      in_width, layer_config->filter_width, layer_config->skip_width);
  const int end_ii_shift = filter_height_half + 1;
  const int end_jj_shift = filter_width_half + 1;
  // *_filter_margin stores the number of pixels along a dimension in the
  // intersection of the complement of the image in the extended image
  // and the filter.
  const int top_filter_margin = layer_config->filter_width * ii_shift;
  const int right_filter_margin = end_jj_shift - in_width;
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    for (int h = start_h, u = 0; h < in_height;
         h += layer_config->skip_height, ++u) {
      const int out_h = u * out_stride;
      const int top_cstep =
          AOMMAX(0, top_filter_margin - h * layer_config->filter_width) *
              cstep +
          i;
      const int start_ii = AOMMAX(0, h - ii_shift);
      const int end_ii = AOMMIN(in_height, h + end_ii_shift);
      for (int w = start_w, out_index = out_h; w < in_width;
           w += layer_config->skip_width, ++out_index) {
        const int left_cstep = AOMMAX(0, jj_shift - w) * cstep;
        const int right_cstep = AOMMAX(0, right_filter_margin + w) * cstep;
        const int start_jj = AOMMAX(0, w - jj_shift);
        const int end_jj = AOMMIN(in_width, w + end_jj_shift);
        float sum = layer_config->bias[i];
        for (int k = 0; k < layer_config->in_channels; ++k) {
          int off = k * layer_config->out_channels + top_cstep;
          for (int ii = start_ii; ii < end_ii; ++ii) {
            off += left_cstep;
            for (int jj = start_jj; jj < end_jj; ++jj, off += cstep) {
              sum += layer_config->weights[off] * input[k][ii * in_stride + jj];
            }
            off += right_cstep;
          }
        }
        output[i][out_index] = sum;
      }
    }
  }
}

// CNNConvolve for the case where maxpool is 0 and padding is
// PADDING_SAME_REPLICATE.
static void convolve_no_maxpool_padding_replicate(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int ii_shift, const int jj_shift,
    const int channel_step) {
  // h and w are shifted to an offset coordinate system to reduce in-loop
  // computation.
  const int start_h =
      get_start_shift_convolve(in_height, layer_config->filter_height,
                               layer_config->skip_height) -
      ii_shift;
  const int start_w =
      get_start_shift_convolve(in_width, layer_config->filter_width,
                               layer_config->skip_width) -
      jj_shift;
  const int end_h = in_height - ii_shift;
  const int end_w = in_width - jj_shift;
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    for (int h = start_h, u = 0; h < end_h;
         h += layer_config->skip_height, ++u) {
      const int out_h = u * out_stride;
      const int upper_ii_index = layer_config->filter_height + h;
      for (int w = start_w, out_index = out_h; w < end_w;
           w += layer_config->skip_width, ++out_index) {
        const int upper_jj_index = layer_config->filter_width + w;
        float sum = layer_config->bias[i];
        for (int k = 0; k < layer_config->in_channels; ++k) {
          int off = k * layer_config->out_channels + i;
          for (int ii = h; ii < upper_ii_index; ++ii) {
            const int clamped_ii = CLAMPINDEX(ii, in_height);
            for (int jj = w; jj < upper_jj_index; ++jj) {
              const int clamped_jj = CLAMPINDEX(jj, in_width);
              assert(clamped_ii >= 0 && clamped_ii < in_height &&
                     clamped_jj >= 0 && clamped_jj < in_width);
              sum += layer_config->weights[off] *
                     input[k][clamped_ii * in_stride + clamped_jj];
              off += cstep;
            }
          }
        }
        output[i][out_index] = sum;
      }
    }
  }
}

// CNNConvolve for the case where maxpool is 0 and padding is
// PADDING_VALID.
void av1_cnn_convolve_no_maxpool_padding_valid_c(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
    int start_idx, int cstep, int channel_step) {
  assert((layer_config->skip_height == 1 && layer_config->skip_width == 1) ||
         !layer_config->maxpool);
  assert(layer_config->filter_height > 1 || layer_config->filter_width > 1);
  assert(layer_config->pad == PADDING_VALID);
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    for (int h = 0, u = 0; h < in_height - layer_config->filter_height + 1;
         h += layer_config->skip_height, ++u) {
      const int out_h = u * out_stride;
      const int upper_ii_index = layer_config->filter_height + h;
      for (int w = 0, out_index = out_h;
           w < in_width - layer_config->filter_width + 1;
           w += layer_config->skip_width, ++out_index) {
        const int upper_jj_index = layer_config->filter_width + w;
        float sum = layer_config->bias[i];
        for (int k = 0; k < layer_config->in_channels; ++k) {
          int off = k * layer_config->out_channels + i;
          for (int ii = h; ii < upper_ii_index; ++ii) {
            for (int jj = w; jj < upper_jj_index; ++jj) {
              assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
              sum += layer_config->weights[off] * input[k][ii * in_stride + jj];
              off += cstep;
            }
          }
        }
        output[i][out_index] = sum;
      }
    }
  }
}

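// Top-level convolution dispatcher. start_idx and step exist to split work
// across threads: the 1x1 (element-wise) path interleaves output columns,
// the other non-maxpool paths interleave output channels, and the maxpool
// paths ignore both.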
static void av1_cnn_convolve(const float **input, int in_width, int in_height,
                             int in_stride,
                             const CNN_LAYER_CONFIG *layer_config,
                             float **output, int out_stride, int start_idx,
                             int step) {
  assert(!layer_config->deconvolve);
  const int cstep = layer_config->in_channels * layer_config->out_channels;
  const int filter_height_half = layer_config->filter_height >> 1;
  const int filter_width_half = layer_config->filter_width >> 1;
  const int channel_step = AOMMAX(step, 1);

  if (layer_config->maxpool &&
      (layer_config->skip_height > 1 || layer_config->skip_width > 1)) {
    switch (layer_config->pad) {
      case PADDING_SAME_ZERO:
        convolve_maxpool_padding_zero(input, in_width, in_height, in_stride,
                                      layer_config, output, out_stride, cstep,
                                      filter_width_half, filter_height_half);
        break;
      case PADDING_SAME_REPLICATE:
        convolve_maxpool_padding_replicate(
            input, in_width, in_height, in_stride, layer_config, output,
            out_stride, cstep, filter_width_half, filter_height_half);
        break;
      case PADDING_VALID:
        convolve_maxpool_padding_valid(input, in_width, in_height, in_stride,
                                       layer_config, output, out_stride, cstep);
        break;
      default: assert(0 && "Unknown padding type");
    }
  } else {
    // A 1x1 filter reduces to an element-wise (per-pixel) matrix
    // multiplication across channels.
    if (layer_config->filter_height == 1 && layer_config->filter_width == 1) {
      convolve_element_wise(input, in_width, in_height, in_stride, layer_config,
                            output, out_stride, start_idx, step);
      return;
    }
    const int ii_shift =
        filter_height_half - (layer_config->filter_height - 1) % 2;
    const int jj_shift =
        filter_width_half - (layer_config->filter_width - 1) % 2;
    switch (layer_config->pad) {
      case PADDING_SAME_ZERO:
        convolve_no_maxpool_padding_zero(
            input, in_width, in_height, in_stride, layer_config, output,
            out_stride, start_idx, cstep, filter_width_half, filter_height_half,
            ii_shift, jj_shift, channel_step);
        break;
      case PADDING_SAME_REPLICATE:
        convolve_no_maxpool_padding_replicate(
            input, in_width, in_height, in_stride, layer_config, output,
            out_stride, start_idx, cstep, ii_shift, jj_shift, channel_step);
        break;
      case PADDING_VALID:
        av1_cnn_convolve_no_maxpool_padding_valid(
            input, in_width, in_height, in_stride, layer_config, output,
            out_stride, start_idx, cstep, channel_step);
        break;
      default: assert(0 && "Unknown padding type");
    }
  }
}

static int convolve_layer(void *arg1, void *arg2) {
  const CONVOLVE_OPS *convolve_ops = arg1;
  (void)arg2;
  av1_cnn_convolve(
      convolve_ops->input, convolve_ops->in_width, convolve_ops->in_height,
      convolve_ops->in_stride, convolve_ops->layer_config, convolve_ops->output,
      convolve_ops->out_stride, convolve_ops->start_idx, convolve_ops->th_step);
  return 1;
}

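// Convolves one layer using the worker pool: worker th runs
// av1_cnn_convolve() with start_idx = th and step = num_workers, with the
// last chunk executed on the calling thread before syncing all workers.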
static void convolve_layer_mt(const float **input, int in_width, int in_height,
                              int in_stride,
                              const CNN_LAYER_CONFIG *layer_config,
                              const CNN_THREAD_DATA *thread_data,
                              float **output, int out_stride) {
  const AVxWorkerInterface *const winterface = aom_get_worker_interface();
  const int num_workers = thread_data->num_workers;
  assert(thread_data->workers);

  CONVOLVE_OPS convolve_ops[CNN_MAX_THREADS];
  for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
    AVxWorker *const worker = &thread_data->workers[th];
    winterface->reset(worker);

    CONVOLVE_OPS convolve_op = { input,      in_width,     in_height,
                                 in_stride,  layer_config, output,
                                 out_stride, th,           num_workers };
    convolve_ops[th] = convolve_op;
    worker->hook = convolve_layer;
    worker->data1 = &(convolve_ops[th]);
    worker->data2 = NULL;

    // Start convolving.
    if (th == num_workers - 1) {
      winterface->execute(worker);
    } else {
      winterface->launch(worker);
    }
  }

  // Wait until all workers have finished.
  for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
    winterface->sync(&thread_data->workers[th]);
  }
}

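// For deconvolution: half the amount by which the filter overhangs the
// upsampling stride.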
static inline int get_start_shift_deconvolve(int filt_width, int stride) {
  const int dif = AOMMAX(filt_width - stride, 0);
  return dif / 2;
}

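// Applies per-channel batch normalization in place:
//   image[ch] = gamma[ch] * (image[ch] - mean[ch]) / std[ch] + beta[ch]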
void av1_cnn_batchnorm_c(float **image, int channels, int width, int height,
                         int stride, const float *gamma, const float *beta,
                         const float *mean, const float *std) {
  assert(gamma && beta && mean && std && "batchnorm has null parameter!");
  for (int ch = 0; ch < channels; ch++) {
    const float ch_gamma = gamma[ch];
    const float ch_beta = beta[ch];
    const float ch_mean = mean[ch];
    const float ch_std = std[ch];
    float *image_row = image[ch];

    for (int row = 0; row < height; row++) {
      for (int col = 0; col < width; col++) {
        image_row[col] =
            ch_gamma * (image_row[col] - ch_mean) / ch_std + ch_beta;
      }
      image_row += stride;
    }
  }
}

void av1_cnn_deconvolve_c(const float **input, int in_width, int in_height,
                          int in_stride, const CNN_LAYER_CONFIG *layer_config,
                          float **output, int out_stride) {
  assert(layer_config->deconvolve);

  const int cstep = layer_config->in_channels * layer_config->out_channels;

  int out_width = 0;
  int out_height = 0;
  av1_find_cnn_layer_output_size(in_width, in_height, layer_config, &out_width,
                                 &out_height);
  switch (layer_config->pad) {
    case PADDING_SAME_ZERO:
      for (int i = 0; i < layer_config->out_channels; ++i) {
        for (int u = 0; u < out_height; ++u) {
          for (int v = 0; v < out_width; ++v) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int h =
                    u - l +
                    get_start_shift_deconvolve(layer_config->filter_height,
                                               layer_config->skip_height);
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int w =
                      v - m +
                      get_start_shift_deconvolve(layer_config->filter_width,
                                                 layer_config->skip_width);
                  if ((h % layer_config->skip_height) != 0 ||
                      (w % layer_config->skip_width) != 0)
                    continue;
                  const int ii = h / layer_config->skip_height;
                  const int jj = w / layer_config->skip_width;
                  if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
                    continue;
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            output[i][u * out_stride + v] = sum;
          }
        }
      }
      break;
    case PADDING_SAME_REPLICATE:
      for (int i = 0; i < layer_config->out_channels; ++i) {
        for (int u = 0; u < out_height; ++u) {
          for (int v = 0; v < out_width; ++v) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int h =
                    u - l +
                    get_start_shift_deconvolve(layer_config->filter_height,
                                               layer_config->skip_height);
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int w =
                      v - m +
                      get_start_shift_deconvolve(layer_config->filter_width,
                                                 layer_config->skip_width);
                  if ((h % layer_config->skip_height) != 0 ||
                      (w % layer_config->skip_width) != 0)
                    continue;
                  const int ii =
                      CLAMPINDEX(h / layer_config->skip_height, in_height);
                  const int jj =
                      CLAMPINDEX(w / layer_config->skip_width, in_width);
                  assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            output[i][u * out_stride + v] = sum;
          }
        }
      }
      break;
    case PADDING_VALID:
      for (int i = 0; i < layer_config->out_channels; ++i) {
        for (int u = 0; u < out_height; ++u) {
          for (int v = 0; v < out_width; ++v) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int h = u - l;
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int w = v - m;
                  if ((h % layer_config->skip_height) != 0 ||
                      (w % layer_config->skip_width) != 0)
                    continue;
                  const int ii = h / layer_config->skip_height;
                  const int jj = w / layer_config->skip_width;
                  if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
                    continue;
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            output[i][u * out_stride + v] = sum;
          }
        }
      }
      break;
    default: assert(0 && "Unknown padding type");
  }
}

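// Runs the full network. Each branch double-buffers its activations in
// tensor1[]/tensor2[]: tensor1 holds a layer's input and tensor2 its output,
// and the two are swapped before the next layer on that branch. Output
// layers (output_num != -1) write directly into the caller's buffers rather
// than into scratch tensors. Returns false if any allocation fails.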
bool av1_cnn_predict_c(const float **input, int in_width, int in_height,
                       int in_stride, const CNN_CONFIG *cnn_config,
                       const CNN_THREAD_DATA *thread_data,
                       CNN_MULTI_OUT *output_struct) {
  bool success = false;
  TENSOR tensor1[CNN_MAX_BRANCHES] = { { 0 } };
  TENSOR tensor2[CNN_MAX_BRANCHES] = { { 0 } };

  float **output[CNN_MAX_BRANCHES];
  const int *out_chs = output_struct->output_channels;
  output[0] = output_struct->output_buffer;
  for (int out_idx = 1; out_idx < output_struct->num_outputs; out_idx++) {
    output[out_idx] = output[out_idx - 1] + out_chs[out_idx - 1];
  }

  int i_width = in_width;
  int i_height = in_height;
  int o_width = 0, o_height = 0;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    init_tensor(&tensor1[b]);
    init_tensor(&tensor2[b]);
  }

  const int *out_stride = output_struct->output_strides;
  for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
    const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
    const int branch = layer_config->branch;
    const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;

    // Allocate input tensor
    if (layer == 0) {       // First layer
      assert(branch == 0);  // First layer must be primary branch
      assign_tensor(&tensor1[branch], (float **)input,
                    layer_config->in_channels, in_width, in_height, in_stride);
    } else {  // Non-first layer
      // Swap tensor1 and tensor2
      swap_tensor(&tensor1[branch], &tensor2[branch]);

      i_width = tensor1[branch].width;
      i_height = tensor1[branch].height;
    }

    // Allocate output tensor
    av1_find_cnn_layer_output_size(i_width, i_height, layer_config, &o_width,
                                   &o_height);
    const int output_num = layer_config->output_num;
    if (output_num == -1) {  // Non-output layer
      if (!realloc_tensor(&tensor2[branch], layer_config->out_channels, o_width,
                          o_height)) {
        goto Error;
      }
    } else {  // Output layer
      free_tensor(&tensor2[branch]);
      assign_tensor(&tensor2[branch], output[output_num],
                    layer_config->out_channels, o_width, o_height,
                    out_stride[output_num]);
    }

    // If we are combining branches make sure that the branch to combine
    // is different from the current branch.
    assert(IMPLIES(layer_config->branch_combine_type != BRANCH_NOC,
                   !(branch_config->branches_to_combine & (1 << branch))));

    if (layer_config->branch_copy_type == BRANCH_INPUT) {
      if (!copy_active_tensor_to_branches(&tensor1[branch], layer_config,
                                          branch, tensor2)) {
        goto Error;
      }
    }
    // Check consistency of input and output channels
    assert(tensor1[branch].channels == layer_config->in_channels);
    assert(tensor2[branch].channels == layer_config->out_channels);

    // Convolve/Deconvolve
    if (!cnn_config->layer_config[layer].deconvolve) {
      if (thread_data->num_workers > 1) {
        convolve_layer_mt((const float **)tensor1[branch].buf,
                          tensor1[branch].width, tensor1[branch].height,
                          tensor1[branch].stride, layer_config, thread_data,
                          tensor2[branch].buf, tensor2[branch].stride);
      } else {
        av1_cnn_convolve((const float **)tensor1[branch].buf,
                         tensor1[branch].width, tensor1[branch].height,
                         tensor1[branch].stride, layer_config,
                         tensor2[branch].buf, tensor2[branch].stride, 0, 1);
      }
    } else {
      av1_cnn_deconvolve((const float **)tensor1[branch].buf,
                         tensor1[branch].width, tensor1[branch].height,
                         tensor1[branch].stride, layer_config,
                         tensor2[branch].buf, tensor2[branch].stride);
    }

    if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
      if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
                                          branch, tensor2)) {
        goto Error;
      }
    }

    // Add tensors from other branches if needed
    if (layer_config->branch_combine_type == BRANCH_ADD) {
      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
        if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
          assert(check_tensor_equal_size(&tensor2[b], &tensor2[branch]));
          av1_cnn_add(tensor2[branch].buf, tensor2[branch].channels,
                      tensor2[branch].width, tensor2[branch].height,
                      tensor2[branch].stride, (const float **)tensor2[b].buf);
        }
      }
    }

    // Non-linearity
    av1_cnn_activate(tensor2[branch].buf, tensor2[branch].channels,
                     tensor2[branch].width, tensor2[branch].height,
                     tensor2[branch].stride, layer_config->activation);

    if (layer_config->bn_params.bn_gamma) {
      av1_cnn_batchnorm(
          tensor2[branch].buf, tensor2[branch].channels, tensor2[branch].width,
          tensor2[branch].height, tensor2[branch].stride,
          layer_config->bn_params.bn_gamma, layer_config->bn_params.bn_beta,
          layer_config->bn_params.bn_mean, layer_config->bn_params.bn_std);
    }

    // Concatenate tensors
    if (layer_config->branch_combine_type == BRANCH_CAT) {
      if (output_num == -1) {  // Non-output layer
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            assert(tensor2[b].channels > 0);
            if (!concat_tensor(&tensor2[b], &tensor2[branch])) goto Error;
          }
        }
      } else {  // Output layer
        const int existing_channels = tensor2[branch].channels;
        int num_chs = existing_channels;
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            // Needed only to assign the new channel buffers
            num_chs += tensor2[b].channels;
          }
        }
        assign_tensor(&tensor2[branch], output[output_num], num_chs, o_width,
                      o_height, out_stride[output_num]);

        num_chs = existing_channels;
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            // Copy branch b's channels in after the ones already in place.
            copy_tensor(&tensor2[b], tensor2[b].channels, num_chs,
                        &tensor2[branch]);
            num_chs += tensor2[b].channels;
          }
        }
      }
    }

    if (layer_config->branch_copy_type == BRANCH_COMBINED) {
      if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
                                          branch, tensor2)) {
        goto Error;
      }
    }
  }

  success = true;
Error:
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    free_tensor(&tensor1[b]);
    free_tensor(&tensor2[b]);
  }
  return success;
}

// Assume output already has proper allocation
// Assume input image buffers all have same resolution and strides
bool av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
                                   int stride, const CNN_CONFIG *cnn_config,
                                   const CNN_THREAD_DATA *thread_data,
                                   CNN_MULTI_OUT *output) {
  const float max_val = 255.0f;

  const int in_width = width + 2 * cnn_config->ext_width;
  const int in_height = height + 2 * cnn_config->ext_height;
  const int in_channels = cnn_config->layer_config[0].in_channels;
  float *inputs[CNN_MAX_CHANNELS];
  float *input_ =
      (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
  if (!input_) return false;
  const int in_stride = in_width;

  for (int c = 0; c < in_channels; ++c) {
    inputs[c] = input_ + c * in_stride * in_height;
    float *input =
        inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;

    if (cnn_config->strict_bounds) {
      for (int i = 0; i < height; ++i)
        for (int j = 0; j < width; ++j)
          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
      // extend left and right
      for (int i = 0; i < height; ++i) {
        for (int j = -cnn_config->ext_width; j < 0; ++j)
          input[i * in_stride + j] = input[i * in_stride];
        for (int j = width; j < width + cnn_config->ext_width; ++j)
          input[i * in_stride + j] = input[i * in_stride + width - 1];
      }
      // extend top and bottom
      for (int i = -cnn_config->ext_height; i < 0; ++i)
        memcpy(&input[i * in_stride - cnn_config->ext_width],
               &input[-cnn_config->ext_width], in_width * sizeof(*input));
      for (int i = height; i < height + cnn_config->ext_height; ++i)
        memcpy(&input[i * in_stride - cnn_config->ext_width],
               &input[(height - 1) * in_stride - cnn_config->ext_width],
               in_width * sizeof(*input));
    } else {
      for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
           ++i)
        for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
             ++j)
          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
    }
  }
  bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
                                 in_stride, cnn_config, thread_data, output);

  aom_free(input_);
  return success;
}

// Assume output already has proper allocation
// Assume input image buffers all have same resolution and strides
bool av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
                                          int stride,
                                          const CNN_CONFIG *cnn_config,
                                          const CNN_THREAD_DATA *thread_data,
                                          int bit_depth,
                                          CNN_MULTI_OUT *output) {
  const float max_val = (float)((1 << bit_depth) - 1);

  const int in_width = width + 2 * cnn_config->ext_width;
  const int in_height = height + 2 * cnn_config->ext_height;
  const int in_channels = cnn_config->layer_config[0].in_channels;
  float *inputs[CNN_MAX_CHANNELS];
  float *input_ =
      (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
  if (!input_) return false;
  const int in_stride = in_width;

  for (int c = 0; c < in_channels; ++c) {
    inputs[c] = input_ + c * in_stride * in_height;
    float *input =
        inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;

    if (cnn_config->strict_bounds) {
      for (int i = 0; i < height; ++i)
        for (int j = 0; j < width; ++j)
          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
      // extend left and right
      for (int i = 0; i < height; ++i) {
        for (int j = -cnn_config->ext_width; j < 0; ++j)
          input[i * in_stride + j] = input[i * in_stride];
        for (int j = width; j < width + cnn_config->ext_width; ++j)
          input[i * in_stride + j] = input[i * in_stride + width - 1];
      }
      // extend top and bottom
      for (int i = -cnn_config->ext_height; i < 0; ++i)
        memcpy(&input[i * in_stride - cnn_config->ext_width],
               &input[-cnn_config->ext_width], in_width * sizeof(*input));
      for (int i = height; i < height + cnn_config->ext_height; ++i)
        memcpy(&input[i * in_stride - cnn_config->ext_width],
               &input[(height - 1) * in_stride - cnn_config->ext_width],
               in_width * sizeof(*input));
    } else {
      for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
           ++i)
        for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
             ++j)
          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
    }
  }

  bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
                                 in_stride, cnn_config, thread_data, output);

  aom_free(input_);
  return success;
}
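
// A minimal caller-side sketch of the multi-output entry point, kept as a
// comment since the real setup lives with the callers. Everything here
// (`cfg`, `dgd_planes`, the single-output assumption, the buffer carving) is
// hypothetical; the struct fields are the ones this file reads from
// CNN_MULTI_OUT and CNN_THREAD_DATA.
//
//   int out_w[1], out_h[1], out_ch[1];
//   av1_find_cnn_output_size(width, height, &cfg, out_w, out_h, out_ch);
//   float *buf = (float *)aom_malloc(sizeof(*buf) * out_w[0] * out_h[0] *
//                                    out_ch[0]);
//   float *planes[CNN_MAX_CHANNELS];  // one plane pointer per output channel
//   for (int c = 0; c < out_ch[0]; ++c)
//     planes[c] = buf + c * out_w[0] * out_h[0];
//   int strides[1] = { out_w[0] };
//   CNN_MULTI_OUT out = { .num_outputs = 1, .output_channels = out_ch,
//                         .output_strides = strides, .output_buffer = planes };
//   CNN_THREAD_DATA td = { .num_workers = 1, .workers = NULL };  // no MT
//   av1_cnn_predict_img_multi_out(dgd_planes, width, height, stride, &cfg,
//                                 &td, &out);
//   // ... consume planes[] ...
//   aom_free(buf);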