// xref: /aosp_15_r20/external/libaom/av1/encoder/cnn.h (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11*77c1e3ccSAndroid Build Coastguard Worker 
12*77c1e3ccSAndroid Build Coastguard Worker #ifndef AOM_AV1_ENCODER_CNN_H_
13*77c1e3ccSAndroid Build Coastguard Worker #define AOM_AV1_ENCODER_CNN_H_
14*77c1e3ccSAndroid Build Coastguard Worker 
15*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus
16*77c1e3ccSAndroid Build Coastguard Worker extern "C" {
17*77c1e3ccSAndroid Build Coastguard Worker #endif
18*77c1e3ccSAndroid Build Coastguard Worker 
19*77c1e3ccSAndroid Build Coastguard Worker #include <math.h>
20*77c1e3ccSAndroid Build Coastguard Worker #include <stdbool.h>
21*77c1e3ccSAndroid Build Coastguard Worker 
22*77c1e3ccSAndroid Build Coastguard Worker #include "aom_util/aom_thread.h"
23*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h"
24*77c1e3ccSAndroid Build Coastguard Worker 
25*77c1e3ccSAndroid Build Coastguard Worker struct AV1Common;
26*77c1e3ccSAndroid Build Coastguard Worker 
// Upper bounds on the size and shape of a CNN described by CNN_CONFIG.
#define CNN_MAX_HIDDEN_LAYERS 64
// The "+ 1" accounts for the final, non-hidden layer
// (CNN_CONFIG::num_layers == number of hidden layers + 1).
#define CNN_MAX_LAYERS (CNN_MAX_HIDDEN_LAYERS + 1)
#define CNN_MAX_CHANNELS 256
#define CNN_MAX_BRANCHES 4
#define CNN_MAX_THREADS 32

// Aggregate initializer for a CNN_BRANCH_CONFIG with every field zero:
// empty copy mask and empty combine mask.
#define NO_BRANCH_CONFIG { 0, 0, 0 }
// Aggregate initializer for a CNN_BATCHNORM_PARAMS whose four pointers are
// all NULL, which disables batch normalization for the layer.
#define NO_BN_PARAMS { NULL, NULL, NULL, NULL }
37*77c1e3ccSAndroid Build Coastguard Worker 
// Convolution padding modes, named after their TensorFlow counterparts.
enum {
  // TensorFlow's SAME padding; pixels outside the image area are assumed
  // to be 0. This is the default.
  PADDING_SAME_ZERO = 0,
  // TensorFlow's SAME padding; pixels outside the image area are replicated
  // from the closest edge.
  PADDING_SAME_REPLICATE = 1,
  // TensorFlow's VALID padding.
  PADDING_VALID = 2
} UENUM1BYTE(PADDING_TYPE);
45*77c1e3ccSAndroid Build Coastguard Worker 
46*77c1e3ccSAndroid Build Coastguard Worker // enum { NONE, RELU, SOFTSIGN } UENUM1BYTE(ACTIVATION);
47*77c1e3ccSAndroid Build Coastguard Worker 
// When, within a layer, the active tensor is copied to the branches given in
// branch_config.input_to_branches:
// BRANCH_NO_COPY: no tensor is copied.
// BRANCH_INPUT: the layer's input tensor is copied to the branches.
// BRANCH_OUTPUT: the convolved tensor is copied to the branches.
// BRANCH_COMBINED: the combined tensor (after convolving and branch
//   combining) is copied. If no combining happens at this layer, this option
//   has the same effect as BRANCH_OUTPUT.
enum {
  BRANCH_NO_COPY,
  BRANCH_INPUT,
  BRANCH_OUTPUT,
  BRANCH_COMBINED
} UENUM1BYTE(BRANCH_COPY);
61*77c1e3ccSAndroid Build Coastguard Worker 
// How previously stored branch tensors are merged with the output of the
// current layer.
enum {
  BRANCH_NOC,  // no branch combining
  BRANCH_ADD,  // add the previously stored branch tensor to the layer output
  BRANCH_CAT   // concatenate the branch tensor to the layer output
} UENUM1BYTE(BRANCH_COMBINE);
67*77c1e3ccSAndroid Build Coastguard Worker 
// Per-channel parameters for batch normalization. The processing is done on a
// per-channel basis, e.g. bn_mean[c] is the mean over all pixels of channel c.
// Batch normalization is always applied after activation, and the output is
//   out[c,i,j] = norm[c,i,j] * bn_gamma[c] + bn_beta[c], where
//   norm[c,i,j] = (in[c,i,j] - bn_mean[c]) / bn_std[c].
// The effect of variance_epsilon is assumed to be already included in bn_std.
// The four pointers must be either all NULL or all valid: if all are NULL,
// batch normalization is disabled; otherwise it is applied.
struct CNN_BATCHNORM_PARAMS {
  const float *bn_gamma;  // per-channel scale
  const float *bn_beta;   // per-channel offset
  const float *bn_mean;   // per-channel mean
  const float *bn_std;    // per-channel std (variance_epsilon already applied)
};
84*77c1e3ccSAndroid Build Coastguard Worker 
// Branch copy/combine settings of a single layer. Branch 0 is the primary
// branch; the masks below are bit masks over branch indices.
struct CNN_BRANCH_CONFIG {
  // If nonzero, copy the active tensor of the current layer and store it for
  // future use in the branches specified by this bit mask. When the copy is
  // taken is selected by the layer's branch_copy_type. For example,
  // input_to_branches = 0x06 copies the tensor to branches 1 and 2. One
  // restriction is that the mask cannot indicate copying to the current
  // branch.
  int input_to_branches;
  // Number of channels of the active tensor copied to the branches given in
  // input_to_branches. If greater than 0, only the channels up to the given
  // index are copied.
  // NOTE(review): a value <= 0 presumably copies all channels; confirm in
  // cnn.c.
  int channels_to_copy;
  // Bit mask of branches to combine with the output of the current layer,
  // used when the layer's branch_combine_type != BRANCH_NOC. For example,
  // branches_to_combine = 0x0A combines branches 1 and 3 with the current
  // branch.
  int branches_to_combine;
};
105*77c1e3ccSAndroid Build Coastguard Worker 
106*77c1e3ccSAndroid Build Coastguard Worker struct CNN_LAYER_CONFIG {
107*77c1e3ccSAndroid Build Coastguard Worker   int in_channels;
108*77c1e3ccSAndroid Build Coastguard Worker   int filter_width;
109*77c1e3ccSAndroid Build Coastguard Worker   int filter_height;
110*77c1e3ccSAndroid Build Coastguard Worker   int out_channels;
111*77c1e3ccSAndroid Build Coastguard Worker   int skip_width;
112*77c1e3ccSAndroid Build Coastguard Worker   int skip_height;
113*77c1e3ccSAndroid Build Coastguard Worker   int maxpool;            // whether to use maxpool or not (only effective when
114*77c1e3ccSAndroid Build Coastguard Worker                           // skip width or skip_height are > 1)
115*77c1e3ccSAndroid Build Coastguard Worker   const float *weights;   // array of length filter_height x filter_width x
116*77c1e3ccSAndroid Build Coastguard Worker                           // in_channels x out_channels where the inner-most
117*77c1e3ccSAndroid Build Coastguard Worker                           // scan is out_channels and the outer most scan is
118*77c1e3ccSAndroid Build Coastguard Worker                           // filter_height.
119*77c1e3ccSAndroid Build Coastguard Worker   const float *bias;      // array of length out_channels
120*77c1e3ccSAndroid Build Coastguard Worker   PADDING_TYPE pad;       // padding type
121*77c1e3ccSAndroid Build Coastguard Worker   ACTIVATION activation;  // the activation function to use after convolution
122*77c1e3ccSAndroid Build Coastguard Worker   int deconvolve;         // whether this is a deconvolution layer.
123*77c1e3ccSAndroid Build Coastguard Worker                           // 0: If skip_width or skip_height are > 1, then we
124*77c1e3ccSAndroid Build Coastguard Worker                           // reduce resolution
125*77c1e3ccSAndroid Build Coastguard Worker                           // 1: If skip_width or skip_height are > 1, then we
126*77c1e3ccSAndroid Build Coastguard Worker                           // increase resolution
127*77c1e3ccSAndroid Build Coastguard Worker   int branch;             // branch index in [0, CNN_MAX_BRANCHES - 1], where
128*77c1e3ccSAndroid Build Coastguard Worker                           // 0 refers to the primary branch.
129*77c1e3ccSAndroid Build Coastguard Worker   BRANCH_COPY branch_copy_type;
130*77c1e3ccSAndroid Build Coastguard Worker   BRANCH_COMBINE branch_combine_type;
131*77c1e3ccSAndroid Build Coastguard Worker   struct CNN_BRANCH_CONFIG branch_config;
132*77c1e3ccSAndroid Build Coastguard Worker   struct CNN_BATCHNORM_PARAMS
133*77c1e3ccSAndroid Build Coastguard Worker       bn_params;   // A struct that contains the parameters
134*77c1e3ccSAndroid Build Coastguard Worker                    // used for batch normalization.
135*77c1e3ccSAndroid Build Coastguard Worker   int output_num;  // The output buffer idx to which the layer output is
136*77c1e3ccSAndroid Build Coastguard Worker                    // written. Set to -1 to disable writing it to the output. In
137*77c1e3ccSAndroid Build Coastguard Worker                    // the case that branch_combine_type is BRANCH_CAT, all
138*77c1e3ccSAndroid Build Coastguard Worker                    // concatenated channels will be written to output. In the
139*77c1e3ccSAndroid Build Coastguard Worker                    // case of BRANCH_ADD, the output will be the result of
140*77c1e3ccSAndroid Build Coastguard Worker                    // summation.
141*77c1e3ccSAndroid Build Coastguard Worker };
142*77c1e3ccSAndroid Build Coastguard Worker 
143*77c1e3ccSAndroid Build Coastguard Worker struct CNN_CONFIG {
144*77c1e3ccSAndroid Build Coastguard Worker   int num_layers;  // number of CNN layers ( = number of hidden layers + 1)
145*77c1e3ccSAndroid Build Coastguard Worker   int is_residue;  // whether the output activation is a residue
146*77c1e3ccSAndroid Build Coastguard Worker   int ext_width, ext_height;  // extension horizontally and vertically
147*77c1e3ccSAndroid Build Coastguard Worker   int strict_bounds;          // whether the input bounds are strict or not.
148*77c1e3ccSAndroid Build Coastguard Worker                               // If strict, the extension area is filled by
149*77c1e3ccSAndroid Build Coastguard Worker                               // replication; if not strict, image data is
150*77c1e3ccSAndroid Build Coastguard Worker                               // assumed available beyond the bounds.
151*77c1e3ccSAndroid Build Coastguard Worker   CNN_LAYER_CONFIG layer_config[CNN_MAX_LAYERS];
152*77c1e3ccSAndroid Build Coastguard Worker };
153*77c1e3ccSAndroid Build Coastguard Worker 
154*77c1e3ccSAndroid Build Coastguard Worker struct CNN_THREAD_DATA {
155*77c1e3ccSAndroid Build Coastguard Worker   int num_workers;
156*77c1e3ccSAndroid Build Coastguard Worker   AVxWorker *workers;
157*77c1e3ccSAndroid Build Coastguard Worker };
158*77c1e3ccSAndroid Build Coastguard Worker 
// Destination buffers for a CNN that produces several outputs. A layer's
// CNN_LAYER_CONFIG::output_num selects which output buffer it writes to.
struct CNN_MULTI_OUT {
  int num_outputs;             // number of outputs
  const int *output_channels;  // per-output channel count
                               // (presumably num_outputs entries)
  const int *output_strides;   // per-output stride
                               // (presumably num_outputs entries)
  float **output_buffer;       // per-output destination buffer
};
165*77c1e3ccSAndroid Build Coastguard Worker 
166*77c1e3ccSAndroid Build Coastguard Worker // Function to return size of output
167*77c1e3ccSAndroid Build Coastguard Worker void av1_find_cnn_output_size(int in_width, int in_height,
168*77c1e3ccSAndroid Build Coastguard Worker                               const CNN_CONFIG *cnn_config, int *out_width,
169*77c1e3ccSAndroid Build Coastguard Worker                               int *out_height, int *out_channels);
170*77c1e3ccSAndroid Build Coastguard Worker 
171*77c1e3ccSAndroid Build Coastguard Worker // Function to return output width and output height of given layer.
172*77c1e3ccSAndroid Build Coastguard Worker void av1_find_cnn_layer_output_size(int in_width, int in_height,
173*77c1e3ccSAndroid Build Coastguard Worker                                     const CNN_LAYER_CONFIG *layer_config,
174*77c1e3ccSAndroid Build Coastguard Worker                                     int *out_width, int *out_height);
175*77c1e3ccSAndroid Build Coastguard Worker 
176*77c1e3ccSAndroid Build Coastguard Worker // Prediction functions from set of input image buffers. This function supports
177*77c1e3ccSAndroid Build Coastguard Worker // CNN with multiple outputs.
178*77c1e3ccSAndroid Build Coastguard Worker bool av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
179*77c1e3ccSAndroid Build Coastguard Worker                                    int stride, const CNN_CONFIG *cnn_config,
180*77c1e3ccSAndroid Build Coastguard Worker                                    const CNN_THREAD_DATA *thread_data,
181*77c1e3ccSAndroid Build Coastguard Worker                                    struct CNN_MULTI_OUT *output);
182*77c1e3ccSAndroid Build Coastguard Worker bool av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
183*77c1e3ccSAndroid Build Coastguard Worker                                           int stride,
184*77c1e3ccSAndroid Build Coastguard Worker                                           const CNN_CONFIG *cnn_config,
185*77c1e3ccSAndroid Build Coastguard Worker                                           const CNN_THREAD_DATA *thread_data,
186*77c1e3ccSAndroid Build Coastguard Worker                                           int bit_depth, CNN_MULTI_OUT *output);
187*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus
188*77c1e3ccSAndroid Build Coastguard Worker }  // extern "C"
189*77c1e3ccSAndroid Build Coastguard Worker #endif
190*77c1e3ccSAndroid Build Coastguard Worker 
191*77c1e3ccSAndroid Build Coastguard Worker #endif  // AOM_AV1_ENCODER_CNN_H_
192