1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker *
4*77c1e3ccSAndroid Build Coastguard Worker * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker */
11*77c1e3ccSAndroid Build Coastguard Worker
#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>

#include "aom_dsp/aom_dsp_common.h"
#include "av1/common/av1_common_int.h"
#include "av1/encoder/cnn.h"
19*77c1e3ccSAndroid Build Coastguard Worker
// Clamps index `a` into [0, hi - 1]: negative values map to 0 and values
// >= hi map to hi - 1. Used for replicate-style edge padding.
// NOTE: both arguments may be evaluated more than once, so they must be free
// of side effects; the result is a valid index only when hi > 0.
#define CLAMPINDEX(a, hi) ((a) < 0 ? 0 : ((a) >= (hi) ? ((hi)-1) : (a)))
21*77c1e3ccSAndroid Build Coastguard Worker
// Argument bundle for one convolution call.
// NOTE(review): the code that fills in and consumes this struct is not in
// this part of the file. The per-field notes below are inferred from the
// field names and from the parameters of the convolve_* helpers; confirm
// them against the users.
typedef struct {
  const float **input;  // Per-channel input planes (presumably).
  int in_width;
  int in_height;
  int in_stride;  // Presumably the input row pitch, in floats.
  const CNN_LAYER_CONFIG *layer_config;
  float **output;  // Per-channel output planes (presumably).
  int out_stride;  // Presumably the output row pitch, in floats.
  int start_idx;   // Presumably the first work item of this worker; confirm.
  int th_step;     // Presumably the step between a worker's items; confirm.
} CONVOLVE_OPS;
33*77c1e3ccSAndroid Build Coastguard Worker
// Softsign activation: x / (1 + |x|), which maps the reals into (-1, 1).
static inline float softsign(float x) {
  const float denom = 1.0f + fabsf(x);
  return x / denom;
}
35*77c1e3ccSAndroid Build Coastguard Worker
// Rectified linear unit. Only strictly negative inputs are replaced by zero,
// so NaN and -0.0f pass through unchanged.
static inline float relu(float x) {
  if (x < 0.0f) return 0.0f;
  return x;
}
37*77c1e3ccSAndroid Build Coastguard Worker
// A stack of `channels` 2-D float planes. Each plane is width x height, and
// its rows are `stride` floats apart.
typedef struct {
  // Number of floats owned at buf[0]. It is non-zero only for tensors
  // allocated by realloc_tensor(). A value of 0 means the planes are borrowed
  // (see assign_tensor()) or absent, and free_tensor() will not release them.
  int allocsize;
  int channels;               // Number of valid entries in buf[].
  int width, height, stride;  // Plane size; stride is the row pitch in floats.
  // Per-channel plane pointers. For owned tensors, plane c lives inside the
  // single buf[0] allocation at offset c * width * height.
  float *buf[CNN_MAX_CHANNELS];
} TENSOR;
44*77c1e3ccSAndroid Build Coastguard Worker
// Zeroes every field of *tensor. The result is an empty tensor with
// allocsize == 0, i.e. one that owns no memory.
static void init_tensor(TENSOR *tensor) {
  memset(tensor, 0, sizeof(TENSOR));
}
46*77c1e3ccSAndroid Build Coastguard Worker
// Releases the buffer owned by `tensor`, if any. Tensors that do not own
// their planes (allocsize == 0) are left untouched.
static void free_tensor(TENSOR *tensor) {
  if (tensor->allocsize == 0) return;
  aom_free(tensor->buf[0]);
  tensor->buf[0] = NULL;
  tensor->allocsize = 0;
}
54*77c1e3ccSAndroid Build Coastguard Worker
// Reshapes `tensor` to hold `channels` planes of width x height floats.
// The planes are stored contiguously: stride == width, and plane c starts
// at buf[0] + c * width * height.
//
// The existing allocation is reused when it is large enough. Otherwise it is
// freed and replaced, and the old contents are NOT copied across.
//
// Returns false if the allocation fails, or if the requested element count
// does not fit in an int (allocsize and the plane offsets are ints).
static bool realloc_tensor(TENSOR *tensor, int channels, int width,
                           int height) {
  // buf[] has a fixed number of channel slots; more channels would make the
  // pointer set-up loop below write past the end of the array.
  assert(channels >= 0 &&
         (size_t)channels <= sizeof(tensor->buf) / sizeof(tensor->buf[0]));
  // Refuse sizes whose product would overflow an int (undefined behaviour).
  if (width > 0 && height > INT_MAX / width) return false;
  const int plane_size = width * height;
  if (plane_size > 0 && channels > INT_MAX / plane_size) return false;
  const int newallocsize = channels * plane_size;

  if (tensor->allocsize < newallocsize) {
    free_tensor(tensor);
    tensor->buf[0] =
        (float *)aom_malloc(sizeof(*tensor->buf[0]) * newallocsize);
    if (!tensor->buf[0]) return false;
    tensor->allocsize = newallocsize;
  }
  tensor->width = width;
  tensor->height = height;
  tensor->stride = width;
  tensor->channels = channels;
  for (int c = 1; c < channels; ++c)
    tensor->buf[c] = &tensor->buf[0][c * plane_size];
  return true;
}
73*77c1e3ccSAndroid Build Coastguard Worker
// Copies the first `copy_channels` planes of `src` into planes
// dst_offset .. dst_offset + copy_channels - 1 of `dst`. Both tensors must
// have the same width and height.
//
// When both tensors are tightly packed (stride == width), each plane is
// copied with a single memcpy; otherwise the copy is done row by row,
// honouring each tensor's stride.
static void copy_tensor(const TENSOR *src, int copy_channels, int dst_offset,
                        TENSOR *dst) {
  assert(src->width == dst->width);
  assert(src->height == dst->height);
  assert(copy_channels <= src->channels);

  const bool both_packed =
      src->stride == dst->width && dst->stride == dst->width;
  for (int ch = 0; ch < copy_channels; ++ch) {
    float *const to = dst->buf[dst_offset + ch];
    const float *const from = src->buf[ch];
    if (both_packed) {
      memcpy(to, from, sizeof(*to) * src->width * src->height);
    } else {
      for (int row = 0; row < dst->height; ++row) {
        memcpy(&to[row * dst->stride], &from[row * src->stride],
               sizeof(*to) * dst->width);
      }
    }
  }
}
94*77c1e3ccSAndroid Build Coastguard Worker
// Turns `tensor` into a non-owning view of the caller's planes
// buf[0..channels-1]. If `buf` is NULL, every plane pointer is set to NULL.
// allocsize is cleared so that free_tensor() never releases the borrowed
// memory.
static void assign_tensor(TENSOR *tensor, float *buf[CNN_MAX_CHANNELS],
                          int channels, int width, int height, int stride) {
  tensor->allocsize = 0;
  tensor->channels = channels;
  tensor->width = width;
  tensor->height = height;
  tensor->stride = stride;
  for (int ch = 0; ch < channels; ++ch) {
    tensor->buf[ch] = buf ? buf[ch] : NULL;
  }
}
108*77c1e3ccSAndroid Build Coastguard Worker
// Exchanges the full contents of two tensors, including buffer ownership
// (allocsize and plane pointers).
static void swap_tensor(TENSOR *t1, TENSOR *t2) {
  const TENSOR saved = *t2;
  *t2 = *t1;
  *t1 = saved;
}
114*77c1e3ccSAndroid Build Coastguard Worker
115*77c1e3ccSAndroid Build Coastguard Worker // The concatenated tensor goes into dst with first the channels in
116*77c1e3ccSAndroid Build Coastguard Worker // original dst followed by the channels in the src
concat_tensor(const TENSOR * src,TENSOR * dst)117*77c1e3ccSAndroid Build Coastguard Worker static bool concat_tensor(const TENSOR *src, TENSOR *dst) {
118*77c1e3ccSAndroid Build Coastguard Worker assert(src->width == dst->width);
119*77c1e3ccSAndroid Build Coastguard Worker assert(src->height == dst->height);
120*77c1e3ccSAndroid Build Coastguard Worker
121*77c1e3ccSAndroid Build Coastguard Worker const int dst_channels = dst->channels;
122*77c1e3ccSAndroid Build Coastguard Worker const int channels = dst->channels + src->channels;
123*77c1e3ccSAndroid Build Coastguard Worker const int newallocsize = channels * dst->width * dst->height;
124*77c1e3ccSAndroid Build Coastguard Worker if (dst->allocsize < newallocsize) {
125*77c1e3ccSAndroid Build Coastguard Worker TENSOR t;
126*77c1e3ccSAndroid Build Coastguard Worker init_tensor(&t);
127*77c1e3ccSAndroid Build Coastguard Worker // allocate new buffers and copy first the dst channels
128*77c1e3ccSAndroid Build Coastguard Worker if (!realloc_tensor(&t, channels, dst->width, dst->height)) return false;
129*77c1e3ccSAndroid Build Coastguard Worker copy_tensor(dst, dst->channels, 0, &t);
130*77c1e3ccSAndroid Build Coastguard Worker // Swap the tensors and free the old buffers
131*77c1e3ccSAndroid Build Coastguard Worker swap_tensor(dst, &t);
132*77c1e3ccSAndroid Build Coastguard Worker free_tensor(&t);
133*77c1e3ccSAndroid Build Coastguard Worker }
134*77c1e3ccSAndroid Build Coastguard Worker for (int c = 1; c < channels; ++c)
135*77c1e3ccSAndroid Build Coastguard Worker dst->buf[c] = &dst->buf[0][c * dst->width * dst->height];
136*77c1e3ccSAndroid Build Coastguard Worker // Copy the channels in src after the first dst_channels channels.
137*77c1e3ccSAndroid Build Coastguard Worker copy_tensor(src, src->channels, dst_channels, dst);
138*77c1e3ccSAndroid Build Coastguard Worker return true;
139*77c1e3ccSAndroid Build Coastguard Worker }
140*77c1e3ccSAndroid Build Coastguard Worker
141*77c1e3ccSAndroid Build Coastguard Worker #ifndef NDEBUG
check_tensor_equal_dims(TENSOR * t1,TENSOR * t2)142*77c1e3ccSAndroid Build Coastguard Worker static int check_tensor_equal_dims(TENSOR *t1, TENSOR *t2) {
143*77c1e3ccSAndroid Build Coastguard Worker return (t1->width == t2->width && t1->height == t2->height);
144*77c1e3ccSAndroid Build Coastguard Worker }
145*77c1e3ccSAndroid Build Coastguard Worker
check_tensor_equal_size(TENSOR * t1,TENSOR * t2)146*77c1e3ccSAndroid Build Coastguard Worker static int check_tensor_equal_size(TENSOR *t1, TENSOR *t2) {
147*77c1e3ccSAndroid Build Coastguard Worker return (t1->channels == t2->channels && t1->width == t2->width &&
148*77c1e3ccSAndroid Build Coastguard Worker t1->height == t2->height);
149*77c1e3ccSAndroid Build Coastguard Worker }
150*77c1e3ccSAndroid Build Coastguard Worker #endif // NDEBUG
151*77c1e3ccSAndroid Build Coastguard Worker
av1_find_cnn_layer_output_size(int in_width,int in_height,const CNN_LAYER_CONFIG * layer_config,int * out_width,int * out_height)152*77c1e3ccSAndroid Build Coastguard Worker void av1_find_cnn_layer_output_size(int in_width, int in_height,
153*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *layer_config,
154*77c1e3ccSAndroid Build Coastguard Worker int *out_width, int *out_height) {
155*77c1e3ccSAndroid Build Coastguard Worker assert(layer_config->skip_width > 0);
156*77c1e3ccSAndroid Build Coastguard Worker assert(layer_config->skip_height > 0);
157*77c1e3ccSAndroid Build Coastguard Worker if (!layer_config->deconvolve) {
158*77c1e3ccSAndroid Build Coastguard Worker switch (layer_config->pad) {
159*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_ZERO:
160*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_REPLICATE:
161*77c1e3ccSAndroid Build Coastguard Worker *out_width = (in_width + layer_config->skip_width - 1) /
162*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_width;
163*77c1e3ccSAndroid Build Coastguard Worker *out_height = (in_height + layer_config->skip_height - 1) /
164*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_height;
165*77c1e3ccSAndroid Build Coastguard Worker break;
166*77c1e3ccSAndroid Build Coastguard Worker case PADDING_VALID:
167*77c1e3ccSAndroid Build Coastguard Worker *out_width =
168*77c1e3ccSAndroid Build Coastguard Worker (in_width - layer_config->filter_width + layer_config->skip_width) /
169*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_width;
170*77c1e3ccSAndroid Build Coastguard Worker *out_height = (in_height - layer_config->filter_height +
171*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_height) /
172*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_height;
173*77c1e3ccSAndroid Build Coastguard Worker break;
174*77c1e3ccSAndroid Build Coastguard Worker default: assert(0 && "Unknown padding type");
175*77c1e3ccSAndroid Build Coastguard Worker }
176*77c1e3ccSAndroid Build Coastguard Worker } else {
177*77c1e3ccSAndroid Build Coastguard Worker switch (layer_config->pad) {
178*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_ZERO:
179*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_REPLICATE:
180*77c1e3ccSAndroid Build Coastguard Worker *out_width = in_width * layer_config->skip_width;
181*77c1e3ccSAndroid Build Coastguard Worker *out_height = in_height * layer_config->skip_height;
182*77c1e3ccSAndroid Build Coastguard Worker break;
183*77c1e3ccSAndroid Build Coastguard Worker case PADDING_VALID:
184*77c1e3ccSAndroid Build Coastguard Worker *out_width = (in_width - 1) * layer_config->skip_width +
185*77c1e3ccSAndroid Build Coastguard Worker layer_config->filter_width;
186*77c1e3ccSAndroid Build Coastguard Worker *out_height = (in_height - 1) * layer_config->skip_height +
187*77c1e3ccSAndroid Build Coastguard Worker layer_config->filter_height;
188*77c1e3ccSAndroid Build Coastguard Worker break;
189*77c1e3ccSAndroid Build Coastguard Worker default: assert(0 && "Unknown padding type");
190*77c1e3ccSAndroid Build Coastguard Worker }
191*77c1e3ccSAndroid Build Coastguard Worker }
192*77c1e3ccSAndroid Build Coastguard Worker }
193*77c1e3ccSAndroid Build Coastguard Worker
// Updates channels_per_branch[], the channel count tracked for every branch
// (indices 0 .. CNN_MAX_BRANCHES - 1), for one layer.
//
// Each branch b != layer_config->branch that is selected in
// branch_config.input_to_branches gets:
//   BRANCH_INPUT:    the layer's in_channels;
//   BRANCH_OUTPUT:   the layer's out_channels;
//   BRANCH_COMBINED: out_channels plus the current count of every branch
//                    c != layer_config->branch in branches_to_combine.
// The layer's own branch then gets out_channels plus the counts of the
// branches in branches_to_combine (excluding itself).
//
// NOTE: the sums read channels_per_branch[] while it is being updated, so the
// result depends on the b / c iteration order below.
static void find_cnn_out_channels(const CNN_LAYER_CONFIG *layer_config,
                                  int channels_per_branch[]) {
  int branch = layer_config->branch;
  const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
      if (layer_config->branch_copy_type == BRANCH_INPUT) {
        channels_per_branch[b] = layer_config->in_channels;
      } else if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
        channels_per_branch[b] = layer_config->out_channels;
      } else if (layer_config->branch_copy_type == BRANCH_COMBINED) {
        channels_per_branch[b] = layer_config->out_channels;
        for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
          if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
            assert(channels_per_branch[c] > 0);
            channels_per_branch[b] += channels_per_branch[c];
          }
        }
      }
    }
  }
  // The layer's own branch carries its output plus the combined branches.
  channels_per_branch[branch] = layer_config->out_channels;
  for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
    if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
      assert(channels_per_branch[c] > 0);
      channels_per_branch[branch] += channels_per_branch[c];
    }
  }
}
223*77c1e3ccSAndroid Build Coastguard Worker
#if CONFIG_DEBUG
// Debug-only sanity check. Returns 1 if at least one layer of cnn_config is
// marked as an output layer (output_num != -1), and 0 otherwise.
static inline int cnn_has_at_least_one_output(const CNN_CONFIG *cnn_config) {
  for (int idx = cnn_config->num_layers - 1; idx >= 0; --idx) {
    if (cnn_config->layer_config[idx].output_num != -1) return 1;
  }
  return 0;
}
#endif
237*77c1e3ccSAndroid Build Coastguard Worker
av1_find_cnn_output_size(int in_width,int in_height,const CNN_CONFIG * cnn_config,int * out_width,int * out_height,int * out_channels)238*77c1e3ccSAndroid Build Coastguard Worker void av1_find_cnn_output_size(int in_width, int in_height,
239*77c1e3ccSAndroid Build Coastguard Worker const CNN_CONFIG *cnn_config, int *out_width,
240*77c1e3ccSAndroid Build Coastguard Worker int *out_height, int *out_channels) {
241*77c1e3ccSAndroid Build Coastguard Worker int channels_per_branch[CNN_MAX_BRANCHES] = { 0 };
242*77c1e3ccSAndroid Build Coastguard Worker int i_width[CNN_MAX_BRANCHES] = { 0 };
243*77c1e3ccSAndroid Build Coastguard Worker int i_height[CNN_MAX_BRANCHES] = { 0 };
244*77c1e3ccSAndroid Build Coastguard Worker i_width[0] = in_width + cnn_config->ext_width * 2;
245*77c1e3ccSAndroid Build Coastguard Worker i_height[0] = in_height + cnn_config->ext_height * 2;
246*77c1e3ccSAndroid Build Coastguard Worker
247*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_DEBUG
248*77c1e3ccSAndroid Build Coastguard Worker assert(cnn_has_at_least_one_output(cnn_config));
249*77c1e3ccSAndroid Build Coastguard Worker #endif
250*77c1e3ccSAndroid Build Coastguard Worker
251*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < cnn_config->num_layers; ++i) {
252*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[i];
253*77c1e3ccSAndroid Build Coastguard Worker const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
254*77c1e3ccSAndroid Build Coastguard Worker const int branch = layer_config->branch;
255*77c1e3ccSAndroid Build Coastguard Worker int o_width = 0, o_height = 0;
256*77c1e3ccSAndroid Build Coastguard Worker
257*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->branch_copy_type == BRANCH_INPUT) {
258*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
259*77c1e3ccSAndroid Build Coastguard Worker if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
260*77c1e3ccSAndroid Build Coastguard Worker assert(i_width[branch] > 0 && i_height[branch] > 0);
261*77c1e3ccSAndroid Build Coastguard Worker i_width[b] = i_width[branch];
262*77c1e3ccSAndroid Build Coastguard Worker i_height[b] = i_height[branch];
263*77c1e3ccSAndroid Build Coastguard Worker }
264*77c1e3ccSAndroid Build Coastguard Worker }
265*77c1e3ccSAndroid Build Coastguard Worker }
266*77c1e3ccSAndroid Build Coastguard Worker
267*77c1e3ccSAndroid Build Coastguard Worker av1_find_cnn_layer_output_size(i_width[branch], i_height[branch],
268*77c1e3ccSAndroid Build Coastguard Worker layer_config, &o_width, &o_height);
269*77c1e3ccSAndroid Build Coastguard Worker i_width[branch] = o_width;
270*77c1e3ccSAndroid Build Coastguard Worker i_height[branch] = o_height;
271*77c1e3ccSAndroid Build Coastguard Worker
272*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
273*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
274*77c1e3ccSAndroid Build Coastguard Worker if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
275*77c1e3ccSAndroid Build Coastguard Worker i_width[b] = o_width;
276*77c1e3ccSAndroid Build Coastguard Worker i_height[b] = o_height;
277*77c1e3ccSAndroid Build Coastguard Worker }
278*77c1e3ccSAndroid Build Coastguard Worker }
279*77c1e3ccSAndroid Build Coastguard Worker }
280*77c1e3ccSAndroid Build Coastguard Worker
281*77c1e3ccSAndroid Build Coastguard Worker find_cnn_out_channels(layer_config, channels_per_branch);
282*77c1e3ccSAndroid Build Coastguard Worker
283*77c1e3ccSAndroid Build Coastguard Worker const int output_num = layer_config->output_num;
284*77c1e3ccSAndroid Build Coastguard Worker if (output_num != -1) { // Current layer is an output layer
285*77c1e3ccSAndroid Build Coastguard Worker out_width[output_num] = o_width;
286*77c1e3ccSAndroid Build Coastguard Worker out_height[output_num] = o_height;
287*77c1e3ccSAndroid Build Coastguard Worker out_channels[output_num] = channels_per_branch[layer_config->branch];
288*77c1e3ccSAndroid Build Coastguard Worker }
289*77c1e3ccSAndroid Build Coastguard Worker }
290*77c1e3ccSAndroid Build Coastguard Worker }
291*77c1e3ccSAndroid Build Coastguard Worker
// Computes a start shift for a strided convolution over `width` samples with
// a `filt_width`-tap filter and step `stride`.
//
// The shift is half of the slack left by width % stride (stride - 1 when
// width divides evenly, otherwise remainder - 1), rounded up for odd filter
// widths. It is capped at the filter's half-width, (filt_width - 1) / 2.
// NOTE(review): presumably the offset of the first sampled filter position;
// confirm against the callers.
static inline int get_start_shift_convolve(int width, int filt_width,
                                           int stride) {
  const int half_filter = (filt_width - 1) / 2;
  const int leftover = width % stride;
  const int slack = (leftover != 0) ? leftover - 1 : stride - 1;
  const int shift = (slack + (filt_width % 2)) / 2;
  return (shift < half_filter) ? shift : half_filter;
}
299*77c1e3ccSAndroid Build Coastguard Worker
// Adds `add` into `output` element by element (output += add). The update
// covers the width x height region of each of the `channels` planes; both
// sets of planes use the same row stride.
void av1_cnn_add_c(float **output, int channels, int width, int height,
                   int stride, const float **add) {
  for (int ch = 0; ch < channels; ++ch) {
    float *const dst = output[ch];
    const float *const src = add[ch];
    for (int row = 0; row < height; ++row) {
      const int row_start = row * stride;
      for (int col = 0; col < width; ++col) {
        dst[row_start + col] += src[row_start + col];
      }
    }
  }
}
308*77c1e3ccSAndroid Build Coastguard Worker
// Applies the element-wise activation `layer_activation` in place. It covers
// the width x height region of each of the `channels` planes in `output`,
// whose rows are `stride` floats apart.
//
// NONE leaves the data untouched. SIGMOID is not implemented and only
// triggers an assert in debug builds.
void av1_cnn_activate_c(float **output, int channels, int width, int height,
                        int stride, ACTIVATION layer_activation) {
  switch (layer_activation) {
    case RELU:
      for (int ch = 0; ch < channels; ++ch) {
        float *const plane = output[ch];
        for (int row = 0; row < height; ++row) {
          for (int col = 0; col < width; ++col) {
            const int idx = row * stride + col;
            plane[idx] = relu(plane[idx]);
          }
        }
      }
      break;
    case SOFTSIGN:
      for (int ch = 0; ch < channels; ++ch) {
        float *const plane = output[ch];
        for (int row = 0; row < height; ++row) {
          for (int col = 0; col < width; ++col) {
            const int idx = row * stride + col;
            plane[idx] = softsign(plane[idx]);
          }
        }
      }
      break;
    case SIGMOID:
      assert(0 && "Sigmoid has not been supported in CNN.");  // TO DO
      break;
    case NONE: break;
    default: assert(0 && "Unknown activation type"); break;
  }
}
329*77c1e3ccSAndroid Build Coastguard Worker
// Copies a layer's active tensor to the other branches that take it as input.
//
// For every branch b != `branch` whose bit is set in
// layer_config->branch_config.input_to_branches, branch_output[b] is resized
// and filled with the first channels_to_copy channels of
// layer_active_tensor, or with all of its channels when
// channels_to_copy <= 0. The copied tensor becomes the input of branch b.
//
// Returns false as soon as an allocation fails.
static bool copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
                                           const CNN_LAYER_CONFIG *layer_config,
                                           int branch, TENSOR branch_output[]) {
  const CNN_BRANCH_CONFIG *const cfg = &layer_config->branch_config;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    if (b == branch || !(cfg->input_to_branches & (1 << b))) continue;

    TENSOR *const dst = &branch_output[b];
    const int n_copy = (cfg->channels_to_copy > 0)
                           ? cfg->channels_to_copy
                           : layer_active_tensor->channels;
    if (!realloc_tensor(dst, n_copy, layer_active_tensor->width,
                        layer_active_tensor->height)) {
      return false;
    }
    copy_tensor(layer_active_tensor, n_copy, 0, dst);
  }
  return true;
}
352*77c1e3ccSAndroid Build Coastguard Worker
// CNNConvolve specific to maxpool set as 1, either skip_width or skip_height
// greater than 1 and padding equal to PADDING_SAME_ZERO.
//
// The computation is a full-resolution convolution followed by max-pooling:
// - For each output channel i, the input is split into windows of
//   skip_height x skip_width samples, anchored at (h, w).
// - The convolution is evaluated at every input position (hh, ww) inside a
//   window.
// - The window's maximum is stored at output[i][u * out_stride + v], where
//   (u, v) counts the windows.
//
// Filter taps that fall outside the input are skipped, i.e. treated as zero
// padding. The weight for input channel k and tap (l, m) is read from
//   weights[k * out_channels + i + (l * filter_width + m) * cstep].
// filter_height_half / filter_width_half are subtracted from the tap index
// to position the filter relative to (hh, ww).
static void convolve_maxpool_padding_zero(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    const int cstep, const int filter_width_half,
    const int filter_height_half) {
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = 0, u = 0; h < in_height; h += layer_config->skip_height, ++u) {
      for (int w = 0, v = 0; w < in_width; w += layer_config->skip_width, ++v) {
        // Visit every input position in the (possibly edge-truncated)
        // pooling window.
        for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
             ++hh) {
          for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
               ++ww) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int ii = hh + l - filter_height_half;
                // `off` advances by cstep for every tap, including the
                // skipped out-of-bounds ones (the increment is in the loop
                // header).
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int jj = ww + m - filter_width_half;
                  if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
                    continue;
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            const float a = sum;
            // The first position of the window initialises the output; the
            // rest keep the running maximum.
            if (h == hh && w == ww)
              output[i][u * out_stride + v] = a;
            else
              output[i][u * out_stride + v] =
                  AOMMAX(output[i][u * out_stride + v], a);
          }
        }
      }
    }
  }
}
394*77c1e3ccSAndroid Build Coastguard Worker
395*77c1e3ccSAndroid Build Coastguard Worker // CNNConvolve specific to maxpool set as 1, either skip_width or skip_height
396*77c1e3ccSAndroid Build Coastguard Worker // greater than 1 and padding equal to PADDING_SAME_REPLICATE.
convolve_maxpool_padding_replicate(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,const int cstep,const int filter_width_half,const int filter_height_half)397*77c1e3ccSAndroid Build Coastguard Worker static void convolve_maxpool_padding_replicate(
398*77c1e3ccSAndroid Build Coastguard Worker const float **input, int in_width, int in_height, int in_stride,
399*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
400*77c1e3ccSAndroid Build Coastguard Worker const int cstep, const int filter_width_half,
401*77c1e3ccSAndroid Build Coastguard Worker const int filter_height_half) {
402*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < layer_config->out_channels; ++i) {
403*77c1e3ccSAndroid Build Coastguard Worker for (int h = 0, u = 0; h < in_height; h += layer_config->skip_height, ++u) {
404*77c1e3ccSAndroid Build Coastguard Worker for (int w = 0, v = 0; w < in_width; w += layer_config->skip_width, ++v) {
405*77c1e3ccSAndroid Build Coastguard Worker for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
406*77c1e3ccSAndroid Build Coastguard Worker ++hh) {
407*77c1e3ccSAndroid Build Coastguard Worker for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
408*77c1e3ccSAndroid Build Coastguard Worker ++ww) {
409*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
410*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
411*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + i;
412*77c1e3ccSAndroid Build Coastguard Worker for (int l = 0; l < layer_config->filter_height; ++l) {
413*77c1e3ccSAndroid Build Coastguard Worker const int ii =
414*77c1e3ccSAndroid Build Coastguard Worker CLAMPINDEX(hh + l - filter_height_half, in_height);
415*77c1e3ccSAndroid Build Coastguard Worker for (int m = 0; m < layer_config->filter_width;
416*77c1e3ccSAndroid Build Coastguard Worker ++m, off += cstep) {
417*77c1e3ccSAndroid Build Coastguard Worker const int jj =
418*77c1e3ccSAndroid Build Coastguard Worker CLAMPINDEX(ww + m - filter_width_half, in_width);
419*77c1e3ccSAndroid Build Coastguard Worker assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
420*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] *
421*77c1e3ccSAndroid Build Coastguard Worker input[k][ii * in_stride + jj];
422*77c1e3ccSAndroid Build Coastguard Worker }
423*77c1e3ccSAndroid Build Coastguard Worker }
424*77c1e3ccSAndroid Build Coastguard Worker }
425*77c1e3ccSAndroid Build Coastguard Worker const float a = sum;
426*77c1e3ccSAndroid Build Coastguard Worker if (h == hh && w == ww)
427*77c1e3ccSAndroid Build Coastguard Worker output[i][u * out_stride + v] = a;
428*77c1e3ccSAndroid Build Coastguard Worker else
429*77c1e3ccSAndroid Build Coastguard Worker output[i][u * out_stride + v] =
430*77c1e3ccSAndroid Build Coastguard Worker AOMMAX(output[i][u * out_stride + v], a);
431*77c1e3ccSAndroid Build Coastguard Worker }
432*77c1e3ccSAndroid Build Coastguard Worker }
433*77c1e3ccSAndroid Build Coastguard Worker }
434*77c1e3ccSAndroid Build Coastguard Worker }
435*77c1e3ccSAndroid Build Coastguard Worker }
436*77c1e3ccSAndroid Build Coastguard Worker }
437*77c1e3ccSAndroid Build Coastguard Worker
438*77c1e3ccSAndroid Build Coastguard Worker // CNNConvolve specific to maxpool set as 1, either skip_width or skip_height
439*77c1e3ccSAndroid Build Coastguard Worker // greater than 1 and padding equal to PADDING_VALID.
convolve_maxpool_padding_valid(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,const int cstep)440*77c1e3ccSAndroid Build Coastguard Worker static void convolve_maxpool_padding_valid(
441*77c1e3ccSAndroid Build Coastguard Worker const float **input, int in_width, int in_height, int in_stride,
442*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
443*77c1e3ccSAndroid Build Coastguard Worker const int cstep) {
444*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < layer_config->out_channels; ++i) {
445*77c1e3ccSAndroid Build Coastguard Worker for (int h = 0, u = 0; h < in_height - layer_config->filter_height + 1;
446*77c1e3ccSAndroid Build Coastguard Worker h += layer_config->skip_height, ++u) {
447*77c1e3ccSAndroid Build Coastguard Worker for (int w = 0, v = 0; w < in_width - layer_config->filter_width + 1;
448*77c1e3ccSAndroid Build Coastguard Worker w += layer_config->skip_width, ++v) {
449*77c1e3ccSAndroid Build Coastguard Worker for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
450*77c1e3ccSAndroid Build Coastguard Worker ++hh) {
451*77c1e3ccSAndroid Build Coastguard Worker for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
452*77c1e3ccSAndroid Build Coastguard Worker ++ww) {
453*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
454*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
455*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + i;
456*77c1e3ccSAndroid Build Coastguard Worker for (int l = 0; l < layer_config->filter_height; ++l) {
457*77c1e3ccSAndroid Build Coastguard Worker const int ii = hh + l;
458*77c1e3ccSAndroid Build Coastguard Worker for (int m = 0; m < layer_config->filter_width;
459*77c1e3ccSAndroid Build Coastguard Worker ++m, off += cstep) {
460*77c1e3ccSAndroid Build Coastguard Worker const int jj = ww + m;
461*77c1e3ccSAndroid Build Coastguard Worker assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
462*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] *
463*77c1e3ccSAndroid Build Coastguard Worker input[k][ii * in_stride + jj];
464*77c1e3ccSAndroid Build Coastguard Worker }
465*77c1e3ccSAndroid Build Coastguard Worker }
466*77c1e3ccSAndroid Build Coastguard Worker }
467*77c1e3ccSAndroid Build Coastguard Worker const float a = sum;
468*77c1e3ccSAndroid Build Coastguard Worker if (h == hh && w == ww)
469*77c1e3ccSAndroid Build Coastguard Worker output[i][u * out_stride + v] = a;
470*77c1e3ccSAndroid Build Coastguard Worker else
471*77c1e3ccSAndroid Build Coastguard Worker output[i][u * out_stride + v] =
472*77c1e3ccSAndroid Build Coastguard Worker AOMMAX(output[i][u * out_stride + v], a);
473*77c1e3ccSAndroid Build Coastguard Worker }
474*77c1e3ccSAndroid Build Coastguard Worker }
475*77c1e3ccSAndroid Build Coastguard Worker }
476*77c1e3ccSAndroid Build Coastguard Worker }
477*77c1e3ccSAndroid Build Coastguard Worker }
478*77c1e3ccSAndroid Build Coastguard Worker }
479*77c1e3ccSAndroid Build Coastguard Worker
480*77c1e3ccSAndroid Build Coastguard Worker // CNNConvolve specific to maxpool set as 0 with filter_height and filter_width
481*77c1e3ccSAndroid Build Coastguard Worker // equal to 1.
convolve_element_wise(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,int start_idx,int step)482*77c1e3ccSAndroid Build Coastguard Worker static void convolve_element_wise(const float **input, int in_width,
483*77c1e3ccSAndroid Build Coastguard Worker int in_height, int in_stride,
484*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *const layer_config,
485*77c1e3ccSAndroid Build Coastguard Worker float **output, int out_stride, int start_idx,
486*77c1e3ccSAndroid Build Coastguard Worker int step) {
487*77c1e3ccSAndroid Build Coastguard Worker const int start_h = get_start_shift_convolve(
488*77c1e3ccSAndroid Build Coastguard Worker in_height, layer_config->filter_height, layer_config->skip_height);
489*77c1e3ccSAndroid Build Coastguard Worker const int start_w =
490*77c1e3ccSAndroid Build Coastguard Worker get_start_shift_convolve(in_width, layer_config->filter_width,
491*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_width) +
492*77c1e3ccSAndroid Build Coastguard Worker start_idx * layer_config->skip_width;
493*77c1e3ccSAndroid Build Coastguard Worker const int out_w_step = AOMMAX(step, 1);
494*77c1e3ccSAndroid Build Coastguard Worker const int in_w_step = layer_config->skip_width * out_w_step;
495*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < layer_config->out_channels; ++i) {
496*77c1e3ccSAndroid Build Coastguard Worker for (int h = start_h, u = 0; h < in_height;
497*77c1e3ccSAndroid Build Coastguard Worker h += layer_config->skip_height, ++u) {
498*77c1e3ccSAndroid Build Coastguard Worker const int in_h = h * in_stride;
499*77c1e3ccSAndroid Build Coastguard Worker const int out_h = u * out_stride + start_idx;
500*77c1e3ccSAndroid Build Coastguard Worker for (int w = start_w, out_index = out_h; w < in_width;
501*77c1e3ccSAndroid Build Coastguard Worker w += in_w_step, out_index += out_w_step) {
502*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
503*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
504*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[k * layer_config->out_channels + i] *
505*77c1e3ccSAndroid Build Coastguard Worker input[k][in_h + w];
506*77c1e3ccSAndroid Build Coastguard Worker }
507*77c1e3ccSAndroid Build Coastguard Worker output[i][out_index] = sum;
508*77c1e3ccSAndroid Build Coastguard Worker }
509*77c1e3ccSAndroid Build Coastguard Worker }
510*77c1e3ccSAndroid Build Coastguard Worker }
511*77c1e3ccSAndroid Build Coastguard Worker }
512*77c1e3ccSAndroid Build Coastguard Worker
513*77c1e3ccSAndroid Build Coastguard Worker // CNNConvolve specific to maxpool set as 0 and padding equal to
514*77c1e3ccSAndroid Build Coastguard Worker // PADDING_SAME_ZERO.
convolve_no_maxpool_padding_zero(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,int start_idx,const int cstep,const int filter_width_half,const int filter_height_half,const int ii_shift,const int jj_shift,const int channel_step)515*77c1e3ccSAndroid Build Coastguard Worker static void convolve_no_maxpool_padding_zero(
516*77c1e3ccSAndroid Build Coastguard Worker const float **input, int in_width, int in_height, int in_stride,
517*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
518*77c1e3ccSAndroid Build Coastguard Worker int start_idx, const int cstep, const int filter_width_half,
519*77c1e3ccSAndroid Build Coastguard Worker const int filter_height_half, const int ii_shift, const int jj_shift,
520*77c1e3ccSAndroid Build Coastguard Worker const int channel_step) {
521*77c1e3ccSAndroid Build Coastguard Worker const int start_h = get_start_shift_convolve(
522*77c1e3ccSAndroid Build Coastguard Worker in_height, layer_config->filter_height, layer_config->skip_height);
523*77c1e3ccSAndroid Build Coastguard Worker const int start_w = get_start_shift_convolve(
524*77c1e3ccSAndroid Build Coastguard Worker in_width, layer_config->filter_width, layer_config->skip_width);
525*77c1e3ccSAndroid Build Coastguard Worker const int end_ii_shift = filter_height_half + 1;
526*77c1e3ccSAndroid Build Coastguard Worker const int end_jj_shift = filter_width_half + 1;
527*77c1e3ccSAndroid Build Coastguard Worker // *_filter_margin stores the number of pixels along a dimension in the
528*77c1e3ccSAndroid Build Coastguard Worker // intersection of the complement of the image in the extended image
529*77c1e3ccSAndroid Build Coastguard Worker // and the filter.
530*77c1e3ccSAndroid Build Coastguard Worker const int top_filter_margin = layer_config->filter_width * ii_shift;
531*77c1e3ccSAndroid Build Coastguard Worker const int right_filter_margin = end_jj_shift - in_width;
532*77c1e3ccSAndroid Build Coastguard Worker for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
533*77c1e3ccSAndroid Build Coastguard Worker for (int h = start_h, u = 0; h < in_height;
534*77c1e3ccSAndroid Build Coastguard Worker h += layer_config->skip_height, ++u) {
535*77c1e3ccSAndroid Build Coastguard Worker const int out_h = u * out_stride;
536*77c1e3ccSAndroid Build Coastguard Worker const int top_cstep =
537*77c1e3ccSAndroid Build Coastguard Worker AOMMAX(0, top_filter_margin - h * layer_config->filter_width) *
538*77c1e3ccSAndroid Build Coastguard Worker cstep +
539*77c1e3ccSAndroid Build Coastguard Worker i;
540*77c1e3ccSAndroid Build Coastguard Worker const int start_ii = AOMMAX(0, h - ii_shift);
541*77c1e3ccSAndroid Build Coastguard Worker const int end_ii = AOMMIN(in_height, h + end_ii_shift);
542*77c1e3ccSAndroid Build Coastguard Worker for (int w = start_w, out_index = out_h; w < in_width;
543*77c1e3ccSAndroid Build Coastguard Worker w += layer_config->skip_width, ++out_index) {
544*77c1e3ccSAndroid Build Coastguard Worker const int left_cstep = AOMMAX(0, jj_shift - w) * cstep;
545*77c1e3ccSAndroid Build Coastguard Worker const int right_cstep = AOMMAX(0, right_filter_margin + w) * cstep;
546*77c1e3ccSAndroid Build Coastguard Worker const int start_jj = AOMMAX(0, w - jj_shift);
547*77c1e3ccSAndroid Build Coastguard Worker const int end_jj = AOMMIN(in_width, w + end_jj_shift);
548*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
549*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
550*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + top_cstep;
551*77c1e3ccSAndroid Build Coastguard Worker for (int ii = start_ii; ii < end_ii; ++ii) {
552*77c1e3ccSAndroid Build Coastguard Worker off += left_cstep;
553*77c1e3ccSAndroid Build Coastguard Worker for (int jj = start_jj; jj < end_jj; ++jj, off += cstep) {
554*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] * input[k][ii * in_stride + jj];
555*77c1e3ccSAndroid Build Coastguard Worker }
556*77c1e3ccSAndroid Build Coastguard Worker off += right_cstep;
557*77c1e3ccSAndroid Build Coastguard Worker }
558*77c1e3ccSAndroid Build Coastguard Worker }
559*77c1e3ccSAndroid Build Coastguard Worker output[i][out_index] = sum;
560*77c1e3ccSAndroid Build Coastguard Worker }
561*77c1e3ccSAndroid Build Coastguard Worker }
562*77c1e3ccSAndroid Build Coastguard Worker }
563*77c1e3ccSAndroid Build Coastguard Worker }
564*77c1e3ccSAndroid Build Coastguard Worker
565*77c1e3ccSAndroid Build Coastguard Worker // CNNConvolve specific to maxpool set as 0 and padding equal to
566*77c1e3ccSAndroid Build Coastguard Worker // PADDING_SAME_REPLICATE.
convolve_no_maxpool_padding_replicate(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,int start_idx,const int cstep,const int ii_shift,const int jj_shift,const int channel_step)567*77c1e3ccSAndroid Build Coastguard Worker static void convolve_no_maxpool_padding_replicate(
568*77c1e3ccSAndroid Build Coastguard Worker const float **input, int in_width, int in_height, int in_stride,
569*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
570*77c1e3ccSAndroid Build Coastguard Worker int start_idx, const int cstep, const int ii_shift, const int jj_shift,
571*77c1e3ccSAndroid Build Coastguard Worker const int channel_step) {
572*77c1e3ccSAndroid Build Coastguard Worker // h and w are shifted to an offset coordinate system to reduce in-loop
573*77c1e3ccSAndroid Build Coastguard Worker // computation.
574*77c1e3ccSAndroid Build Coastguard Worker const int start_h =
575*77c1e3ccSAndroid Build Coastguard Worker get_start_shift_convolve(in_height, layer_config->filter_height,
576*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_height) -
577*77c1e3ccSAndroid Build Coastguard Worker ii_shift;
578*77c1e3ccSAndroid Build Coastguard Worker const int start_w =
579*77c1e3ccSAndroid Build Coastguard Worker get_start_shift_convolve(in_width, layer_config->filter_width,
580*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_width) -
581*77c1e3ccSAndroid Build Coastguard Worker jj_shift;
582*77c1e3ccSAndroid Build Coastguard Worker const int end_h = in_height - ii_shift;
583*77c1e3ccSAndroid Build Coastguard Worker const int end_w = in_width - jj_shift;
584*77c1e3ccSAndroid Build Coastguard Worker for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
585*77c1e3ccSAndroid Build Coastguard Worker for (int h = start_h, u = 0; h < end_h;
586*77c1e3ccSAndroid Build Coastguard Worker h += layer_config->skip_height, ++u) {
587*77c1e3ccSAndroid Build Coastguard Worker const int out_h = u * out_stride;
588*77c1e3ccSAndroid Build Coastguard Worker const int upper_ii_index = layer_config->filter_height + h;
589*77c1e3ccSAndroid Build Coastguard Worker for (int w = start_w, out_index = out_h; w < end_w;
590*77c1e3ccSAndroid Build Coastguard Worker w += layer_config->skip_width, ++out_index) {
591*77c1e3ccSAndroid Build Coastguard Worker const int upper_jj_index = layer_config->filter_width + w;
592*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
593*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
594*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + i;
595*77c1e3ccSAndroid Build Coastguard Worker for (int ii = h; ii < upper_ii_index; ++ii) {
596*77c1e3ccSAndroid Build Coastguard Worker const int clamped_ii = CLAMPINDEX(ii, in_height);
597*77c1e3ccSAndroid Build Coastguard Worker for (int jj = w; jj < upper_jj_index; ++jj) {
598*77c1e3ccSAndroid Build Coastguard Worker const int clamped_jj = CLAMPINDEX(jj, in_width);
599*77c1e3ccSAndroid Build Coastguard Worker assert(clamped_ii >= 0 && clamped_ii < in_height &&
600*77c1e3ccSAndroid Build Coastguard Worker clamped_jj >= 0 && clamped_jj < in_width);
601*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] *
602*77c1e3ccSAndroid Build Coastguard Worker input[k][clamped_ii * in_stride + clamped_jj];
603*77c1e3ccSAndroid Build Coastguard Worker off += cstep;
604*77c1e3ccSAndroid Build Coastguard Worker }
605*77c1e3ccSAndroid Build Coastguard Worker }
606*77c1e3ccSAndroid Build Coastguard Worker }
607*77c1e3ccSAndroid Build Coastguard Worker output[i][out_index] = sum;
608*77c1e3ccSAndroid Build Coastguard Worker }
609*77c1e3ccSAndroid Build Coastguard Worker }
610*77c1e3ccSAndroid Build Coastguard Worker }
611*77c1e3ccSAndroid Build Coastguard Worker }
612*77c1e3ccSAndroid Build Coastguard Worker
613*77c1e3ccSAndroid Build Coastguard Worker // CNNConvolve specific to maxpool set as 0 and padding equal to
614*77c1e3ccSAndroid Build Coastguard Worker // PADDING_VALID.
av1_cnn_convolve_no_maxpool_padding_valid_c(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,float ** output,int out_stride,int start_idx,int cstep,int channel_step)615*77c1e3ccSAndroid Build Coastguard Worker void av1_cnn_convolve_no_maxpool_padding_valid_c(
616*77c1e3ccSAndroid Build Coastguard Worker const float **input, int in_width, int in_height, int in_stride,
617*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
618*77c1e3ccSAndroid Build Coastguard Worker int start_idx, int cstep, int channel_step) {
619*77c1e3ccSAndroid Build Coastguard Worker assert((layer_config->skip_height == 1 && layer_config->skip_width == 1) ||
620*77c1e3ccSAndroid Build Coastguard Worker !layer_config->maxpool);
621*77c1e3ccSAndroid Build Coastguard Worker assert(layer_config->filter_height > 1 || layer_config->filter_width > 1);
622*77c1e3ccSAndroid Build Coastguard Worker assert(layer_config->pad == PADDING_VALID);
623*77c1e3ccSAndroid Build Coastguard Worker for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
624*77c1e3ccSAndroid Build Coastguard Worker for (int h = 0, u = 0; h < in_height - layer_config->filter_height + 1;
625*77c1e3ccSAndroid Build Coastguard Worker h += layer_config->skip_height, ++u) {
626*77c1e3ccSAndroid Build Coastguard Worker const int out_h = u * out_stride;
627*77c1e3ccSAndroid Build Coastguard Worker const int upper_ii_index = layer_config->filter_height + h;
628*77c1e3ccSAndroid Build Coastguard Worker for (int w = 0, out_index = out_h;
629*77c1e3ccSAndroid Build Coastguard Worker w < in_width - layer_config->filter_width + 1;
630*77c1e3ccSAndroid Build Coastguard Worker w += layer_config->skip_width, ++out_index) {
631*77c1e3ccSAndroid Build Coastguard Worker const int upper_jj_index = layer_config->filter_width + w;
632*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
633*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
634*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + i;
635*77c1e3ccSAndroid Build Coastguard Worker for (int ii = h; ii < upper_ii_index; ++ii) {
636*77c1e3ccSAndroid Build Coastguard Worker for (int jj = w; jj < upper_jj_index; ++jj) {
637*77c1e3ccSAndroid Build Coastguard Worker assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
638*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] * input[k][ii * in_stride + jj];
639*77c1e3ccSAndroid Build Coastguard Worker off += cstep;
640*77c1e3ccSAndroid Build Coastguard Worker }
641*77c1e3ccSAndroid Build Coastguard Worker }
642*77c1e3ccSAndroid Build Coastguard Worker }
643*77c1e3ccSAndroid Build Coastguard Worker output[i][out_index] = sum;
644*77c1e3ccSAndroid Build Coastguard Worker }
645*77c1e3ccSAndroid Build Coastguard Worker }
646*77c1e3ccSAndroid Build Coastguard Worker }
647*77c1e3ccSAndroid Build Coastguard Worker }
648*77c1e3ccSAndroid Build Coastguard Worker
av1_cnn_convolve(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,float ** output,int out_stride,int start_idx,int step)649*77c1e3ccSAndroid Build Coastguard Worker static void av1_cnn_convolve(const float **input, int in_width, int in_height,
650*77c1e3ccSAndroid Build Coastguard Worker int in_stride,
651*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *layer_config,
652*77c1e3ccSAndroid Build Coastguard Worker float **output, int out_stride, int start_idx,
653*77c1e3ccSAndroid Build Coastguard Worker int step) {
654*77c1e3ccSAndroid Build Coastguard Worker assert(!layer_config->deconvolve);
655*77c1e3ccSAndroid Build Coastguard Worker const int cstep = layer_config->in_channels * layer_config->out_channels;
656*77c1e3ccSAndroid Build Coastguard Worker const int filter_height_half = layer_config->filter_height >> 1;
657*77c1e3ccSAndroid Build Coastguard Worker const int filter_width_half = layer_config->filter_width >> 1;
658*77c1e3ccSAndroid Build Coastguard Worker const int channel_step = AOMMAX(step, 1);
659*77c1e3ccSAndroid Build Coastguard Worker
660*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->maxpool &&
661*77c1e3ccSAndroid Build Coastguard Worker (layer_config->skip_height > 1 || layer_config->skip_width > 1)) {
662*77c1e3ccSAndroid Build Coastguard Worker switch (layer_config->pad) {
663*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_ZERO:
664*77c1e3ccSAndroid Build Coastguard Worker convolve_maxpool_padding_zero(input, in_width, in_height, in_stride,
665*77c1e3ccSAndroid Build Coastguard Worker layer_config, output, out_stride, cstep,
666*77c1e3ccSAndroid Build Coastguard Worker filter_width_half, filter_height_half);
667*77c1e3ccSAndroid Build Coastguard Worker break;
668*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_REPLICATE:
669*77c1e3ccSAndroid Build Coastguard Worker convolve_maxpool_padding_replicate(
670*77c1e3ccSAndroid Build Coastguard Worker input, in_width, in_height, in_stride, layer_config, output,
671*77c1e3ccSAndroid Build Coastguard Worker out_stride, cstep, filter_width_half, filter_height_half);
672*77c1e3ccSAndroid Build Coastguard Worker break;
673*77c1e3ccSAndroid Build Coastguard Worker case PADDING_VALID:
674*77c1e3ccSAndroid Build Coastguard Worker convolve_maxpool_padding_valid(input, in_width, in_height, in_stride,
675*77c1e3ccSAndroid Build Coastguard Worker layer_config, output, out_stride, cstep);
676*77c1e3ccSAndroid Build Coastguard Worker break;
677*77c1e3ccSAndroid Build Coastguard Worker default: assert(0 && "Unknown padding type");
678*77c1e3ccSAndroid Build Coastguard Worker }
679*77c1e3ccSAndroid Build Coastguard Worker } else {
680*77c1e3ccSAndroid Build Coastguard Worker // Results in element-wise matrix multiplication.
681*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->filter_height == 1 && layer_config->filter_width == 1) {
682*77c1e3ccSAndroid Build Coastguard Worker convolve_element_wise(input, in_width, in_height, in_stride, layer_config,
683*77c1e3ccSAndroid Build Coastguard Worker output, out_stride, start_idx, step);
684*77c1e3ccSAndroid Build Coastguard Worker return;
685*77c1e3ccSAndroid Build Coastguard Worker }
686*77c1e3ccSAndroid Build Coastguard Worker const int ii_shift =
687*77c1e3ccSAndroid Build Coastguard Worker filter_height_half - (layer_config->filter_height - 1) % 2;
688*77c1e3ccSAndroid Build Coastguard Worker const int jj_shift =
689*77c1e3ccSAndroid Build Coastguard Worker filter_width_half - (layer_config->filter_width - 1) % 2;
690*77c1e3ccSAndroid Build Coastguard Worker switch (layer_config->pad) {
691*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_ZERO:
692*77c1e3ccSAndroid Build Coastguard Worker convolve_no_maxpool_padding_zero(
693*77c1e3ccSAndroid Build Coastguard Worker input, in_width, in_height, in_stride, layer_config, output,
694*77c1e3ccSAndroid Build Coastguard Worker out_stride, start_idx, cstep, filter_width_half, filter_height_half,
695*77c1e3ccSAndroid Build Coastguard Worker ii_shift, jj_shift, channel_step);
696*77c1e3ccSAndroid Build Coastguard Worker break;
697*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_REPLICATE:
698*77c1e3ccSAndroid Build Coastguard Worker convolve_no_maxpool_padding_replicate(
699*77c1e3ccSAndroid Build Coastguard Worker input, in_width, in_height, in_stride, layer_config, output,
700*77c1e3ccSAndroid Build Coastguard Worker out_stride, start_idx, cstep, ii_shift, jj_shift, channel_step);
701*77c1e3ccSAndroid Build Coastguard Worker break;
702*77c1e3ccSAndroid Build Coastguard Worker case PADDING_VALID:
703*77c1e3ccSAndroid Build Coastguard Worker av1_cnn_convolve_no_maxpool_padding_valid(
704*77c1e3ccSAndroid Build Coastguard Worker input, in_width, in_height, in_stride, layer_config, output,
705*77c1e3ccSAndroid Build Coastguard Worker out_stride, start_idx, cstep, channel_step);
706*77c1e3ccSAndroid Build Coastguard Worker break;
707*77c1e3ccSAndroid Build Coastguard Worker default: assert(0 && "Unknown padding type");
708*77c1e3ccSAndroid Build Coastguard Worker }
709*77c1e3ccSAndroid Build Coastguard Worker }
710*77c1e3ccSAndroid Build Coastguard Worker }
711*77c1e3ccSAndroid Build Coastguard Worker
convolve_layer(void * arg1,void * arg2)712*77c1e3ccSAndroid Build Coastguard Worker static int convolve_layer(void *arg1, void *arg2) {
713*77c1e3ccSAndroid Build Coastguard Worker const CONVOLVE_OPS *convolve_ops = arg1;
714*77c1e3ccSAndroid Build Coastguard Worker (void)arg2;
715*77c1e3ccSAndroid Build Coastguard Worker av1_cnn_convolve(
716*77c1e3ccSAndroid Build Coastguard Worker convolve_ops->input, convolve_ops->in_width, convolve_ops->in_height,
717*77c1e3ccSAndroid Build Coastguard Worker convolve_ops->in_stride, convolve_ops->layer_config, convolve_ops->output,
718*77c1e3ccSAndroid Build Coastguard Worker convolve_ops->out_stride, convolve_ops->start_idx, convolve_ops->th_step);
719*77c1e3ccSAndroid Build Coastguard Worker return 1;
720*77c1e3ccSAndroid Build Coastguard Worker }
721*77c1e3ccSAndroid Build Coastguard Worker
convolve_layer_mt(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,const CNN_THREAD_DATA * thread_data,float ** output,int out_stride)722*77c1e3ccSAndroid Build Coastguard Worker static void convolve_layer_mt(const float **input, int in_width, int in_height,
723*77c1e3ccSAndroid Build Coastguard Worker int in_stride,
724*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *layer_config,
725*77c1e3ccSAndroid Build Coastguard Worker const CNN_THREAD_DATA *thread_data,
726*77c1e3ccSAndroid Build Coastguard Worker float **output, int out_stride) {
727*77c1e3ccSAndroid Build Coastguard Worker const AVxWorkerInterface *const winterface = aom_get_worker_interface();
728*77c1e3ccSAndroid Build Coastguard Worker const int num_workers = thread_data->num_workers;
729*77c1e3ccSAndroid Build Coastguard Worker assert(thread_data->workers);
730*77c1e3ccSAndroid Build Coastguard Worker
731*77c1e3ccSAndroid Build Coastguard Worker CONVOLVE_OPS convolve_ops[CNN_MAX_THREADS];
732*77c1e3ccSAndroid Build Coastguard Worker for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
733*77c1e3ccSAndroid Build Coastguard Worker AVxWorker *const worker = &thread_data->workers[th];
734*77c1e3ccSAndroid Build Coastguard Worker winterface->reset(worker);
735*77c1e3ccSAndroid Build Coastguard Worker
736*77c1e3ccSAndroid Build Coastguard Worker CONVOLVE_OPS convolve_op = { input, in_width, in_height,
737*77c1e3ccSAndroid Build Coastguard Worker in_stride, layer_config, output,
738*77c1e3ccSAndroid Build Coastguard Worker out_stride, th, num_workers };
739*77c1e3ccSAndroid Build Coastguard Worker convolve_ops[th] = convolve_op;
740*77c1e3ccSAndroid Build Coastguard Worker worker->hook = convolve_layer;
741*77c1e3ccSAndroid Build Coastguard Worker worker->data1 = &(convolve_ops[th]);
742*77c1e3ccSAndroid Build Coastguard Worker worker->data2 = NULL;
743*77c1e3ccSAndroid Build Coastguard Worker
744*77c1e3ccSAndroid Build Coastguard Worker // Start convolving.
745*77c1e3ccSAndroid Build Coastguard Worker if (th == num_workers - 1) {
746*77c1e3ccSAndroid Build Coastguard Worker winterface->execute(worker);
747*77c1e3ccSAndroid Build Coastguard Worker } else {
748*77c1e3ccSAndroid Build Coastguard Worker winterface->launch(worker);
749*77c1e3ccSAndroid Build Coastguard Worker }
750*77c1e3ccSAndroid Build Coastguard Worker }
751*77c1e3ccSAndroid Build Coastguard Worker
752*77c1e3ccSAndroid Build Coastguard Worker // Wait until all workers have finished.
753*77c1e3ccSAndroid Build Coastguard Worker for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
754*77c1e3ccSAndroid Build Coastguard Worker winterface->sync(&thread_data->workers[th]);
755*77c1e3ccSAndroid Build Coastguard Worker }
756*77c1e3ccSAndroid Build Coastguard Worker }
757*77c1e3ccSAndroid Build Coastguard Worker
// Start shift for the deconvolution kernels: half of the amount by which
// filt_width exceeds stride (integer division), or 0 when the filter is not
// wider than the stride.
static inline int get_start_shift_deconvolve(int filt_width, int stride) {
  const int overhang = filt_width - stride;
  return overhang > 0 ? overhang / 2 : 0;
}
762*77c1e3ccSAndroid Build Coastguard Worker
// Applies batch normalization in place to each channel plane:
//   x = gamma[ch] * (x - mean[ch]) / std[ch] + beta[ch]
// for every element at image[ch][row * stride + col] with 0 <= row < height
// and 0 <= col < width. Elements between width and stride are not touched.
//
// image:  `channels` plane pointers, each addressed with `stride` floats per
//         row.
// gamma, beta, mean, std: per-channel parameters, indexed 0..channels-1.
//         std is used directly as the divisor; no epsilon is added here.
void av1_cnn_batchnorm_c(float **image, int channels, int width, int height,
                         int stride, const float *gamma, const float *beta,
                         const float *mean, const float *std) {
  // Check every parameter array that is dereferenced below. The old check
  // listed beta twice and missed mean.
  assert(gamma && beta && mean && std && "batchnorm has null parameter!");
  for (int ch = 0; ch < channels; ch++) {
    const float ch_gamma = gamma[ch];
    const float ch_beta = beta[ch];
    const float ch_mean = mean[ch];
    const float ch_std = std[ch];
    float *image_row = image[ch];

    for (int row = 0; row < height; row++) {
      for (int col = 0; col < width; col++) {
        image_row[col] =
            ch_gamma * (image_row[col] - ch_mean) / ch_std + ch_beta;
      }
      image_row += stride;
    }
  }
}
783*77c1e3ccSAndroid Build Coastguard Worker
av1_cnn_deconvolve_c(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,float ** output,int out_stride)784*77c1e3ccSAndroid Build Coastguard Worker void av1_cnn_deconvolve_c(const float **input, int in_width, int in_height,
785*77c1e3ccSAndroid Build Coastguard Worker int in_stride, const CNN_LAYER_CONFIG *layer_config,
786*77c1e3ccSAndroid Build Coastguard Worker float **output, int out_stride) {
787*77c1e3ccSAndroid Build Coastguard Worker assert(layer_config->deconvolve);
788*77c1e3ccSAndroid Build Coastguard Worker
789*77c1e3ccSAndroid Build Coastguard Worker const int cstep = layer_config->in_channels * layer_config->out_channels;
790*77c1e3ccSAndroid Build Coastguard Worker
791*77c1e3ccSAndroid Build Coastguard Worker int out_width = 0;
792*77c1e3ccSAndroid Build Coastguard Worker int out_height = 0;
793*77c1e3ccSAndroid Build Coastguard Worker av1_find_cnn_layer_output_size(in_width, in_height, layer_config, &out_width,
794*77c1e3ccSAndroid Build Coastguard Worker &out_height);
795*77c1e3ccSAndroid Build Coastguard Worker switch (layer_config->pad) {
796*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_ZERO:
797*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < layer_config->out_channels; ++i) {
798*77c1e3ccSAndroid Build Coastguard Worker for (int u = 0; u < out_height; ++u) {
799*77c1e3ccSAndroid Build Coastguard Worker for (int v = 0; v < out_width; ++v) {
800*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
801*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
802*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + i;
803*77c1e3ccSAndroid Build Coastguard Worker for (int l = 0; l < layer_config->filter_height; ++l) {
804*77c1e3ccSAndroid Build Coastguard Worker const int h =
805*77c1e3ccSAndroid Build Coastguard Worker u - l +
806*77c1e3ccSAndroid Build Coastguard Worker get_start_shift_deconvolve(layer_config->filter_height,
807*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_height);
808*77c1e3ccSAndroid Build Coastguard Worker for (int m = 0; m < layer_config->filter_width;
809*77c1e3ccSAndroid Build Coastguard Worker ++m, off += cstep) {
810*77c1e3ccSAndroid Build Coastguard Worker const int w =
811*77c1e3ccSAndroid Build Coastguard Worker v - m +
812*77c1e3ccSAndroid Build Coastguard Worker get_start_shift_deconvolve(layer_config->filter_width,
813*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_width);
814*77c1e3ccSAndroid Build Coastguard Worker if ((h % layer_config->skip_height) != 0 ||
815*77c1e3ccSAndroid Build Coastguard Worker (w % layer_config->skip_width) != 0)
816*77c1e3ccSAndroid Build Coastguard Worker continue;
817*77c1e3ccSAndroid Build Coastguard Worker const int ii = h / layer_config->skip_height;
818*77c1e3ccSAndroid Build Coastguard Worker const int jj = w / layer_config->skip_width;
819*77c1e3ccSAndroid Build Coastguard Worker if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
820*77c1e3ccSAndroid Build Coastguard Worker continue;
821*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] *
822*77c1e3ccSAndroid Build Coastguard Worker input[k][ii * in_stride + jj];
823*77c1e3ccSAndroid Build Coastguard Worker }
824*77c1e3ccSAndroid Build Coastguard Worker }
825*77c1e3ccSAndroid Build Coastguard Worker }
826*77c1e3ccSAndroid Build Coastguard Worker output[i][u * out_stride + v] = sum;
827*77c1e3ccSAndroid Build Coastguard Worker }
828*77c1e3ccSAndroid Build Coastguard Worker }
829*77c1e3ccSAndroid Build Coastguard Worker }
830*77c1e3ccSAndroid Build Coastguard Worker break;
831*77c1e3ccSAndroid Build Coastguard Worker case PADDING_SAME_REPLICATE:
832*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < layer_config->out_channels; ++i) {
833*77c1e3ccSAndroid Build Coastguard Worker for (int u = 0; u < out_height; ++u) {
834*77c1e3ccSAndroid Build Coastguard Worker for (int v = 0; v < out_width; ++v) {
835*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
836*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
837*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + i;
838*77c1e3ccSAndroid Build Coastguard Worker for (int l = 0; l < layer_config->filter_height; ++l) {
839*77c1e3ccSAndroid Build Coastguard Worker const int h =
840*77c1e3ccSAndroid Build Coastguard Worker u - l +
841*77c1e3ccSAndroid Build Coastguard Worker get_start_shift_deconvolve(layer_config->filter_height,
842*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_height);
843*77c1e3ccSAndroid Build Coastguard Worker for (int m = 0; m < layer_config->filter_width;
844*77c1e3ccSAndroid Build Coastguard Worker ++m, off += cstep) {
845*77c1e3ccSAndroid Build Coastguard Worker const int w =
846*77c1e3ccSAndroid Build Coastguard Worker v - m +
847*77c1e3ccSAndroid Build Coastguard Worker get_start_shift_deconvolve(layer_config->filter_width,
848*77c1e3ccSAndroid Build Coastguard Worker layer_config->skip_width);
849*77c1e3ccSAndroid Build Coastguard Worker if ((h % layer_config->skip_height) != 0 ||
850*77c1e3ccSAndroid Build Coastguard Worker (w % layer_config->skip_width) != 0)
851*77c1e3ccSAndroid Build Coastguard Worker continue;
852*77c1e3ccSAndroid Build Coastguard Worker const int ii =
853*77c1e3ccSAndroid Build Coastguard Worker CLAMPINDEX(h / layer_config->skip_height, in_height);
854*77c1e3ccSAndroid Build Coastguard Worker const int jj =
855*77c1e3ccSAndroid Build Coastguard Worker CLAMPINDEX(w / layer_config->skip_width, in_width);
856*77c1e3ccSAndroid Build Coastguard Worker assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
857*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] *
858*77c1e3ccSAndroid Build Coastguard Worker input[k][ii * in_stride + jj];
859*77c1e3ccSAndroid Build Coastguard Worker }
860*77c1e3ccSAndroid Build Coastguard Worker }
861*77c1e3ccSAndroid Build Coastguard Worker }
862*77c1e3ccSAndroid Build Coastguard Worker output[i][u * out_stride + v] = sum;
863*77c1e3ccSAndroid Build Coastguard Worker }
864*77c1e3ccSAndroid Build Coastguard Worker }
865*77c1e3ccSAndroid Build Coastguard Worker }
866*77c1e3ccSAndroid Build Coastguard Worker break;
867*77c1e3ccSAndroid Build Coastguard Worker case PADDING_VALID:
868*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < layer_config->out_channels; ++i) {
869*77c1e3ccSAndroid Build Coastguard Worker for (int u = 0; u < out_height; ++u) {
870*77c1e3ccSAndroid Build Coastguard Worker for (int v = 0; v < out_width; ++v) {
871*77c1e3ccSAndroid Build Coastguard Worker float sum = layer_config->bias[i];
872*77c1e3ccSAndroid Build Coastguard Worker for (int k = 0; k < layer_config->in_channels; ++k) {
873*77c1e3ccSAndroid Build Coastguard Worker int off = k * layer_config->out_channels + i;
874*77c1e3ccSAndroid Build Coastguard Worker for (int l = 0; l < layer_config->filter_height; ++l) {
875*77c1e3ccSAndroid Build Coastguard Worker const int h = u - l;
876*77c1e3ccSAndroid Build Coastguard Worker for (int m = 0; m < layer_config->filter_width;
877*77c1e3ccSAndroid Build Coastguard Worker ++m, off += cstep) {
878*77c1e3ccSAndroid Build Coastguard Worker const int w = v - m;
879*77c1e3ccSAndroid Build Coastguard Worker if ((h % layer_config->skip_height) != 0 ||
880*77c1e3ccSAndroid Build Coastguard Worker (w % layer_config->skip_width) != 0)
881*77c1e3ccSAndroid Build Coastguard Worker continue;
882*77c1e3ccSAndroid Build Coastguard Worker const int ii = h / layer_config->skip_height;
883*77c1e3ccSAndroid Build Coastguard Worker const int jj = w / layer_config->skip_width;
884*77c1e3ccSAndroid Build Coastguard Worker if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
885*77c1e3ccSAndroid Build Coastguard Worker continue;
886*77c1e3ccSAndroid Build Coastguard Worker sum += layer_config->weights[off] *
887*77c1e3ccSAndroid Build Coastguard Worker input[k][ii * in_stride + jj];
888*77c1e3ccSAndroid Build Coastguard Worker }
889*77c1e3ccSAndroid Build Coastguard Worker }
890*77c1e3ccSAndroid Build Coastguard Worker }
891*77c1e3ccSAndroid Build Coastguard Worker output[i][u * out_stride + v] = sum;
892*77c1e3ccSAndroid Build Coastguard Worker }
893*77c1e3ccSAndroid Build Coastguard Worker }
894*77c1e3ccSAndroid Build Coastguard Worker }
895*77c1e3ccSAndroid Build Coastguard Worker break;
896*77c1e3ccSAndroid Build Coastguard Worker default: assert(0 && "Unknown padding type");
897*77c1e3ccSAndroid Build Coastguard Worker }
898*77c1e3ccSAndroid Build Coastguard Worker }
899*77c1e3ccSAndroid Build Coastguard Worker
av1_cnn_predict_c(const float ** input,int in_width,int in_height,int in_stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output_struct)900*77c1e3ccSAndroid Build Coastguard Worker bool av1_cnn_predict_c(const float **input, int in_width, int in_height,
901*77c1e3ccSAndroid Build Coastguard Worker int in_stride, const CNN_CONFIG *cnn_config,
902*77c1e3ccSAndroid Build Coastguard Worker const CNN_THREAD_DATA *thread_data,
903*77c1e3ccSAndroid Build Coastguard Worker CNN_MULTI_OUT *output_struct) {
904*77c1e3ccSAndroid Build Coastguard Worker bool success = false;
905*77c1e3ccSAndroid Build Coastguard Worker TENSOR tensor1[CNN_MAX_BRANCHES] = { { 0 } };
906*77c1e3ccSAndroid Build Coastguard Worker TENSOR tensor2[CNN_MAX_BRANCHES] = { { 0 } };
907*77c1e3ccSAndroid Build Coastguard Worker
908*77c1e3ccSAndroid Build Coastguard Worker float **output[CNN_MAX_BRANCHES];
909*77c1e3ccSAndroid Build Coastguard Worker const int *out_chs = output_struct->output_channels;
910*77c1e3ccSAndroid Build Coastguard Worker output[0] = output_struct->output_buffer;
911*77c1e3ccSAndroid Build Coastguard Worker for (int out_idx = 1; out_idx < output_struct->num_outputs; out_idx++) {
912*77c1e3ccSAndroid Build Coastguard Worker output[out_idx] = output[out_idx - 1] + out_chs[out_idx - 1];
913*77c1e3ccSAndroid Build Coastguard Worker }
914*77c1e3ccSAndroid Build Coastguard Worker
915*77c1e3ccSAndroid Build Coastguard Worker int i_width = in_width;
916*77c1e3ccSAndroid Build Coastguard Worker int i_height = in_height;
917*77c1e3ccSAndroid Build Coastguard Worker int o_width = 0, o_height = 0;
918*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
919*77c1e3ccSAndroid Build Coastguard Worker init_tensor(&tensor1[b]);
920*77c1e3ccSAndroid Build Coastguard Worker init_tensor(&tensor2[b]);
921*77c1e3ccSAndroid Build Coastguard Worker }
922*77c1e3ccSAndroid Build Coastguard Worker
923*77c1e3ccSAndroid Build Coastguard Worker const int *out_stride = output_struct->output_strides;
924*77c1e3ccSAndroid Build Coastguard Worker for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
925*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
926*77c1e3ccSAndroid Build Coastguard Worker const int branch = layer_config->branch;
927*77c1e3ccSAndroid Build Coastguard Worker const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
928*77c1e3ccSAndroid Build Coastguard Worker
929*77c1e3ccSAndroid Build Coastguard Worker // Allocate input tensor
930*77c1e3ccSAndroid Build Coastguard Worker if (layer == 0) { // First layer
931*77c1e3ccSAndroid Build Coastguard Worker assert(branch == 0); // First layer must be primary branch
932*77c1e3ccSAndroid Build Coastguard Worker assign_tensor(&tensor1[branch], (float **)input,
933*77c1e3ccSAndroid Build Coastguard Worker layer_config->in_channels, in_width, in_height, in_stride);
934*77c1e3ccSAndroid Build Coastguard Worker } else { // Non-first layer
935*77c1e3ccSAndroid Build Coastguard Worker // Swap tensor1 and tensor2
936*77c1e3ccSAndroid Build Coastguard Worker swap_tensor(&tensor1[branch], &tensor2[branch]);
937*77c1e3ccSAndroid Build Coastguard Worker
938*77c1e3ccSAndroid Build Coastguard Worker i_width = tensor1[branch].width;
939*77c1e3ccSAndroid Build Coastguard Worker i_height = tensor1[branch].height;
940*77c1e3ccSAndroid Build Coastguard Worker }
941*77c1e3ccSAndroid Build Coastguard Worker
942*77c1e3ccSAndroid Build Coastguard Worker // Allocate output tensor
943*77c1e3ccSAndroid Build Coastguard Worker av1_find_cnn_layer_output_size(i_width, i_height, layer_config, &o_width,
944*77c1e3ccSAndroid Build Coastguard Worker &o_height);
945*77c1e3ccSAndroid Build Coastguard Worker const int output_num = layer_config->output_num;
946*77c1e3ccSAndroid Build Coastguard Worker if (output_num == -1) { // Non-output layer
947*77c1e3ccSAndroid Build Coastguard Worker if (!realloc_tensor(&tensor2[branch], layer_config->out_channels, o_width,
948*77c1e3ccSAndroid Build Coastguard Worker o_height)) {
949*77c1e3ccSAndroid Build Coastguard Worker goto Error;
950*77c1e3ccSAndroid Build Coastguard Worker }
951*77c1e3ccSAndroid Build Coastguard Worker } else { // Output layer
952*77c1e3ccSAndroid Build Coastguard Worker free_tensor(&tensor2[branch]);
953*77c1e3ccSAndroid Build Coastguard Worker assign_tensor(&tensor2[branch], output[output_num],
954*77c1e3ccSAndroid Build Coastguard Worker layer_config->out_channels, o_width, o_height,
955*77c1e3ccSAndroid Build Coastguard Worker out_stride[output_num]);
956*77c1e3ccSAndroid Build Coastguard Worker }
957*77c1e3ccSAndroid Build Coastguard Worker
958*77c1e3ccSAndroid Build Coastguard Worker // If we are combining branches make sure that the branch to combine
959*77c1e3ccSAndroid Build Coastguard Worker // is different from the current branch.
960*77c1e3ccSAndroid Build Coastguard Worker assert(IMPLIES(layer_config->branch_combine_type != BRANCH_NOC,
961*77c1e3ccSAndroid Build Coastguard Worker !(branch_config->branches_to_combine & (1 << branch))));
962*77c1e3ccSAndroid Build Coastguard Worker
963*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->branch_copy_type == BRANCH_INPUT) {
964*77c1e3ccSAndroid Build Coastguard Worker if (!copy_active_tensor_to_branches(&tensor1[branch], layer_config,
965*77c1e3ccSAndroid Build Coastguard Worker branch, tensor2)) {
966*77c1e3ccSAndroid Build Coastguard Worker goto Error;
967*77c1e3ccSAndroid Build Coastguard Worker }
968*77c1e3ccSAndroid Build Coastguard Worker }
969*77c1e3ccSAndroid Build Coastguard Worker // Check consistency of input and output channels
970*77c1e3ccSAndroid Build Coastguard Worker assert(tensor1[branch].channels == layer_config->in_channels);
971*77c1e3ccSAndroid Build Coastguard Worker assert(tensor2[branch].channels == layer_config->out_channels);
972*77c1e3ccSAndroid Build Coastguard Worker
973*77c1e3ccSAndroid Build Coastguard Worker // Convolve/Deconvolve
974*77c1e3ccSAndroid Build Coastguard Worker if (!cnn_config->layer_config[layer].deconvolve) {
975*77c1e3ccSAndroid Build Coastguard Worker if (thread_data->num_workers > 1) {
976*77c1e3ccSAndroid Build Coastguard Worker convolve_layer_mt((const float **)tensor1[branch].buf,
977*77c1e3ccSAndroid Build Coastguard Worker tensor1[branch].width, tensor1[branch].height,
978*77c1e3ccSAndroid Build Coastguard Worker tensor1[branch].stride, layer_config, thread_data,
979*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].buf, tensor2[branch].stride);
980*77c1e3ccSAndroid Build Coastguard Worker } else {
981*77c1e3ccSAndroid Build Coastguard Worker av1_cnn_convolve((const float **)tensor1[branch].buf,
982*77c1e3ccSAndroid Build Coastguard Worker tensor1[branch].width, tensor1[branch].height,
983*77c1e3ccSAndroid Build Coastguard Worker tensor1[branch].stride, layer_config,
984*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].buf, tensor2[branch].stride, 0, 1);
985*77c1e3ccSAndroid Build Coastguard Worker }
986*77c1e3ccSAndroid Build Coastguard Worker } else {
987*77c1e3ccSAndroid Build Coastguard Worker av1_cnn_deconvolve((const float **)tensor1[branch].buf,
988*77c1e3ccSAndroid Build Coastguard Worker tensor1[branch].width, tensor1[branch].height,
989*77c1e3ccSAndroid Build Coastguard Worker tensor1[branch].stride, layer_config,
990*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].buf, tensor2[branch].stride);
991*77c1e3ccSAndroid Build Coastguard Worker }
992*77c1e3ccSAndroid Build Coastguard Worker
993*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
994*77c1e3ccSAndroid Build Coastguard Worker if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
995*77c1e3ccSAndroid Build Coastguard Worker branch, tensor2)) {
996*77c1e3ccSAndroid Build Coastguard Worker goto Error;
997*77c1e3ccSAndroid Build Coastguard Worker }
998*77c1e3ccSAndroid Build Coastguard Worker }
999*77c1e3ccSAndroid Build Coastguard Worker
1000*77c1e3ccSAndroid Build Coastguard Worker // Add tensors from other branches if needed
1001*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->branch_combine_type == BRANCH_ADD) {
1002*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
1003*77c1e3ccSAndroid Build Coastguard Worker if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
1004*77c1e3ccSAndroid Build Coastguard Worker assert(check_tensor_equal_size(&tensor2[b], &tensor2[branch]));
1005*77c1e3ccSAndroid Build Coastguard Worker av1_cnn_add(tensor2[branch].buf, tensor2[branch].channels,
1006*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].width, tensor2[branch].height,
1007*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].stride, (const float **)tensor2[b].buf);
1008*77c1e3ccSAndroid Build Coastguard Worker }
1009*77c1e3ccSAndroid Build Coastguard Worker }
1010*77c1e3ccSAndroid Build Coastguard Worker }
1011*77c1e3ccSAndroid Build Coastguard Worker
1012*77c1e3ccSAndroid Build Coastguard Worker // Non-linearity
1013*77c1e3ccSAndroid Build Coastguard Worker av1_cnn_activate(tensor2[branch].buf, tensor2[branch].channels,
1014*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].width, tensor2[branch].height,
1015*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].stride, layer_config->activation);
1016*77c1e3ccSAndroid Build Coastguard Worker
1017*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->bn_params.bn_gamma) {
1018*77c1e3ccSAndroid Build Coastguard Worker av1_cnn_batchnorm(
1019*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].buf, tensor2[branch].channels, tensor2[branch].width,
1020*77c1e3ccSAndroid Build Coastguard Worker tensor2[branch].height, tensor2[branch].stride,
1021*77c1e3ccSAndroid Build Coastguard Worker layer_config->bn_params.bn_gamma, layer_config->bn_params.bn_beta,
1022*77c1e3ccSAndroid Build Coastguard Worker layer_config->bn_params.bn_mean, layer_config->bn_params.bn_std);
1023*77c1e3ccSAndroid Build Coastguard Worker }
1024*77c1e3ccSAndroid Build Coastguard Worker
1025*77c1e3ccSAndroid Build Coastguard Worker // Concatenate tensors
1026*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->branch_combine_type == BRANCH_CAT) {
1027*77c1e3ccSAndroid Build Coastguard Worker if (output_num == -1) { // Non-output layer
1028*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
1029*77c1e3ccSAndroid Build Coastguard Worker if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
1030*77c1e3ccSAndroid Build Coastguard Worker assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
1031*77c1e3ccSAndroid Build Coastguard Worker assert(tensor2[b].channels > 0);
1032*77c1e3ccSAndroid Build Coastguard Worker if (!concat_tensor(&tensor2[b], &tensor2[branch])) goto Error;
1033*77c1e3ccSAndroid Build Coastguard Worker }
1034*77c1e3ccSAndroid Build Coastguard Worker }
1035*77c1e3ccSAndroid Build Coastguard Worker } else { // Output layer
1036*77c1e3ccSAndroid Build Coastguard Worker const int existing_channels = tensor2[branch].channels;
1037*77c1e3ccSAndroid Build Coastguard Worker int num_chs = existing_channels;
1038*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
1039*77c1e3ccSAndroid Build Coastguard Worker if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
1040*77c1e3ccSAndroid Build Coastguard Worker assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
1041*77c1e3ccSAndroid Build Coastguard Worker // Needed only to assign the new channel buffers
1042*77c1e3ccSAndroid Build Coastguard Worker num_chs += tensor2[b].channels;
1043*77c1e3ccSAndroid Build Coastguard Worker }
1044*77c1e3ccSAndroid Build Coastguard Worker }
1045*77c1e3ccSAndroid Build Coastguard Worker assign_tensor(&tensor2[branch], output[output_num], num_chs, o_width,
1046*77c1e3ccSAndroid Build Coastguard Worker o_height, out_stride[output_num]);
1047*77c1e3ccSAndroid Build Coastguard Worker
1048*77c1e3ccSAndroid Build Coastguard Worker num_chs = existing_channels;
1049*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
1050*77c1e3ccSAndroid Build Coastguard Worker if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
1051*77c1e3ccSAndroid Build Coastguard Worker assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
1052*77c1e3ccSAndroid Build Coastguard Worker // Needed only to assign the new channel buffers
1053*77c1e3ccSAndroid Build Coastguard Worker copy_tensor(&tensor2[b], tensor2[b].channels, num_chs,
1054*77c1e3ccSAndroid Build Coastguard Worker &tensor2[branch]);
1055*77c1e3ccSAndroid Build Coastguard Worker num_chs += tensor2[b].channels;
1056*77c1e3ccSAndroid Build Coastguard Worker }
1057*77c1e3ccSAndroid Build Coastguard Worker }
1058*77c1e3ccSAndroid Build Coastguard Worker }
1059*77c1e3ccSAndroid Build Coastguard Worker }
1060*77c1e3ccSAndroid Build Coastguard Worker
1061*77c1e3ccSAndroid Build Coastguard Worker if (layer_config->branch_copy_type == BRANCH_COMBINED) {
1062*77c1e3ccSAndroid Build Coastguard Worker if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
1063*77c1e3ccSAndroid Build Coastguard Worker branch, tensor2)) {
1064*77c1e3ccSAndroid Build Coastguard Worker goto Error;
1065*77c1e3ccSAndroid Build Coastguard Worker }
1066*77c1e3ccSAndroid Build Coastguard Worker }
1067*77c1e3ccSAndroid Build Coastguard Worker }
1068*77c1e3ccSAndroid Build Coastguard Worker
1069*77c1e3ccSAndroid Build Coastguard Worker success = true;
1070*77c1e3ccSAndroid Build Coastguard Worker Error:
1071*77c1e3ccSAndroid Build Coastguard Worker for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
1072*77c1e3ccSAndroid Build Coastguard Worker free_tensor(&tensor1[b]);
1073*77c1e3ccSAndroid Build Coastguard Worker free_tensor(&tensor2[b]);
1074*77c1e3ccSAndroid Build Coastguard Worker }
1075*77c1e3ccSAndroid Build Coastguard Worker return success;
1076*77c1e3ccSAndroid Build Coastguard Worker }
1077*77c1e3ccSAndroid Build Coastguard Worker
1078*77c1e3ccSAndroid Build Coastguard Worker // Assume output already has proper allocation
1079*77c1e3ccSAndroid Build Coastguard Worker // Assume input image buffers all have same resolution and strides
av1_cnn_predict_img_multi_out(uint8_t ** dgd,int width,int height,int stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output)1080*77c1e3ccSAndroid Build Coastguard Worker bool av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
1081*77c1e3ccSAndroid Build Coastguard Worker int stride, const CNN_CONFIG *cnn_config,
1082*77c1e3ccSAndroid Build Coastguard Worker const CNN_THREAD_DATA *thread_data,
1083*77c1e3ccSAndroid Build Coastguard Worker CNN_MULTI_OUT *output) {
1084*77c1e3ccSAndroid Build Coastguard Worker const float max_val = 255.0;
1085*77c1e3ccSAndroid Build Coastguard Worker
1086*77c1e3ccSAndroid Build Coastguard Worker const int in_width = width + 2 * cnn_config->ext_width;
1087*77c1e3ccSAndroid Build Coastguard Worker const int in_height = height + 2 * cnn_config->ext_height;
1088*77c1e3ccSAndroid Build Coastguard Worker const int in_channels = cnn_config->layer_config[0].in_channels;
1089*77c1e3ccSAndroid Build Coastguard Worker float *inputs[CNN_MAX_CHANNELS];
1090*77c1e3ccSAndroid Build Coastguard Worker float *input_ =
1091*77c1e3ccSAndroid Build Coastguard Worker (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
1092*77c1e3ccSAndroid Build Coastguard Worker if (!input_) return false;
1093*77c1e3ccSAndroid Build Coastguard Worker const int in_stride = in_width;
1094*77c1e3ccSAndroid Build Coastguard Worker
1095*77c1e3ccSAndroid Build Coastguard Worker for (int c = 0; c < in_channels; ++c) {
1096*77c1e3ccSAndroid Build Coastguard Worker inputs[c] = input_ + c * in_stride * in_height;
1097*77c1e3ccSAndroid Build Coastguard Worker float *input =
1098*77c1e3ccSAndroid Build Coastguard Worker inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;
1099*77c1e3ccSAndroid Build Coastguard Worker
1100*77c1e3ccSAndroid Build Coastguard Worker if (cnn_config->strict_bounds) {
1101*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < height; ++i)
1102*77c1e3ccSAndroid Build Coastguard Worker for (int j = 0; j < width; ++j)
1103*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1104*77c1e3ccSAndroid Build Coastguard Worker // extend left and right
1105*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < height; ++i) {
1106*77c1e3ccSAndroid Build Coastguard Worker for (int j = -cnn_config->ext_width; j < 0; ++j)
1107*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = input[i * in_stride];
1108*77c1e3ccSAndroid Build Coastguard Worker for (int j = width; j < width + cnn_config->ext_width; ++j)
1109*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = input[i * in_stride + width - 1];
1110*77c1e3ccSAndroid Build Coastguard Worker }
1111*77c1e3ccSAndroid Build Coastguard Worker // extend top and bottom
1112*77c1e3ccSAndroid Build Coastguard Worker for (int i = -cnn_config->ext_height; i < 0; ++i)
1113*77c1e3ccSAndroid Build Coastguard Worker memcpy(&input[i * in_stride - cnn_config->ext_width],
1114*77c1e3ccSAndroid Build Coastguard Worker &input[-cnn_config->ext_width], in_width * sizeof(*input));
1115*77c1e3ccSAndroid Build Coastguard Worker for (int i = height; i < height + cnn_config->ext_height; ++i)
1116*77c1e3ccSAndroid Build Coastguard Worker memcpy(&input[i * in_stride - cnn_config->ext_width],
1117*77c1e3ccSAndroid Build Coastguard Worker &input[(height - 1) * in_stride - cnn_config->ext_width],
1118*77c1e3ccSAndroid Build Coastguard Worker in_width * sizeof(*input));
1119*77c1e3ccSAndroid Build Coastguard Worker } else {
1120*77c1e3ccSAndroid Build Coastguard Worker for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
1121*77c1e3ccSAndroid Build Coastguard Worker ++i)
1122*77c1e3ccSAndroid Build Coastguard Worker for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
1123*77c1e3ccSAndroid Build Coastguard Worker ++j)
1124*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1125*77c1e3ccSAndroid Build Coastguard Worker }
1126*77c1e3ccSAndroid Build Coastguard Worker }
1127*77c1e3ccSAndroid Build Coastguard Worker bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
1128*77c1e3ccSAndroid Build Coastguard Worker in_stride, cnn_config, thread_data, output);
1129*77c1e3ccSAndroid Build Coastguard Worker
1130*77c1e3ccSAndroid Build Coastguard Worker aom_free(input_);
1131*77c1e3ccSAndroid Build Coastguard Worker return success;
1132*77c1e3ccSAndroid Build Coastguard Worker }
1133*77c1e3ccSAndroid Build Coastguard Worker
1134*77c1e3ccSAndroid Build Coastguard Worker // Assume output already has proper allocation
1135*77c1e3ccSAndroid Build Coastguard Worker // Assume input image buffers all have same resolution and strides
av1_cnn_predict_img_multi_out_highbd(uint16_t ** dgd,int width,int height,int stride,const CNN_CONFIG * cnn_config,const CNN_THREAD_DATA * thread_data,int bit_depth,CNN_MULTI_OUT * output)1136*77c1e3ccSAndroid Build Coastguard Worker bool av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
1137*77c1e3ccSAndroid Build Coastguard Worker int stride,
1138*77c1e3ccSAndroid Build Coastguard Worker const CNN_CONFIG *cnn_config,
1139*77c1e3ccSAndroid Build Coastguard Worker const CNN_THREAD_DATA *thread_data,
1140*77c1e3ccSAndroid Build Coastguard Worker int bit_depth,
1141*77c1e3ccSAndroid Build Coastguard Worker CNN_MULTI_OUT *output) {
1142*77c1e3ccSAndroid Build Coastguard Worker const float max_val = (float)((1 << bit_depth) - 1);
1143*77c1e3ccSAndroid Build Coastguard Worker
1144*77c1e3ccSAndroid Build Coastguard Worker const int in_width = width + 2 * cnn_config->ext_width;
1145*77c1e3ccSAndroid Build Coastguard Worker const int in_height = height + 2 * cnn_config->ext_height;
1146*77c1e3ccSAndroid Build Coastguard Worker const int in_channels = cnn_config->layer_config[0].in_channels;
1147*77c1e3ccSAndroid Build Coastguard Worker float *inputs[CNN_MAX_CHANNELS];
1148*77c1e3ccSAndroid Build Coastguard Worker float *input_ =
1149*77c1e3ccSAndroid Build Coastguard Worker (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
1150*77c1e3ccSAndroid Build Coastguard Worker if (!input_) return false;
1151*77c1e3ccSAndroid Build Coastguard Worker const int in_stride = in_width;
1152*77c1e3ccSAndroid Build Coastguard Worker
1153*77c1e3ccSAndroid Build Coastguard Worker for (int c = 0; c < in_channels; ++c) {
1154*77c1e3ccSAndroid Build Coastguard Worker inputs[c] = input_ + c * in_stride * in_height;
1155*77c1e3ccSAndroid Build Coastguard Worker float *input =
1156*77c1e3ccSAndroid Build Coastguard Worker inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;
1157*77c1e3ccSAndroid Build Coastguard Worker
1158*77c1e3ccSAndroid Build Coastguard Worker if (cnn_config->strict_bounds) {
1159*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < height; ++i)
1160*77c1e3ccSAndroid Build Coastguard Worker for (int j = 0; j < width; ++j)
1161*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1162*77c1e3ccSAndroid Build Coastguard Worker // extend left and right
1163*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < height; ++i) {
1164*77c1e3ccSAndroid Build Coastguard Worker for (int j = -cnn_config->ext_width; j < 0; ++j)
1165*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = input[i * in_stride];
1166*77c1e3ccSAndroid Build Coastguard Worker for (int j = width; j < width + cnn_config->ext_width; ++j)
1167*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = input[i * in_stride + width - 1];
1168*77c1e3ccSAndroid Build Coastguard Worker }
1169*77c1e3ccSAndroid Build Coastguard Worker // extend top and bottom
1170*77c1e3ccSAndroid Build Coastguard Worker for (int i = -cnn_config->ext_height; i < 0; ++i)
1171*77c1e3ccSAndroid Build Coastguard Worker memcpy(&input[i * in_stride - cnn_config->ext_width],
1172*77c1e3ccSAndroid Build Coastguard Worker &input[-cnn_config->ext_width], in_width * sizeof(*input));
1173*77c1e3ccSAndroid Build Coastguard Worker for (int i = height; i < height + cnn_config->ext_height; ++i)
1174*77c1e3ccSAndroid Build Coastguard Worker memcpy(&input[i * in_stride - cnn_config->ext_width],
1175*77c1e3ccSAndroid Build Coastguard Worker &input[(height - 1) * in_stride - cnn_config->ext_width],
1176*77c1e3ccSAndroid Build Coastguard Worker in_width * sizeof(*input));
1177*77c1e3ccSAndroid Build Coastguard Worker } else {
1178*77c1e3ccSAndroid Build Coastguard Worker for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
1179*77c1e3ccSAndroid Build Coastguard Worker ++i)
1180*77c1e3ccSAndroid Build Coastguard Worker for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
1181*77c1e3ccSAndroid Build Coastguard Worker ++j)
1182*77c1e3ccSAndroid Build Coastguard Worker input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
1183*77c1e3ccSAndroid Build Coastguard Worker }
1184*77c1e3ccSAndroid Build Coastguard Worker }
1185*77c1e3ccSAndroid Build Coastguard Worker
1186*77c1e3ccSAndroid Build Coastguard Worker bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
1187*77c1e3ccSAndroid Build Coastguard Worker in_stride, cnn_config, thread_data, output);
1188*77c1e3ccSAndroid Build Coastguard Worker
1189*77c1e3ccSAndroid Build Coastguard Worker aom_free(input_);
1190*77c1e3ccSAndroid Build Coastguard Worker return success;
1191*77c1e3ccSAndroid Build Coastguard Worker }
1192