1*77c1e3ccSAndroid Build Coastguard Worker /* 2*77c1e3ccSAndroid Build Coastguard Worker * Copyright (c) 2019, Alliance for Open Media. All rights reserved. 3*77c1e3ccSAndroid Build Coastguard Worker * 4*77c1e3ccSAndroid Build Coastguard Worker * This source code is subject to the terms of the BSD 2 Clause License and 5*77c1e3ccSAndroid Build Coastguard Worker * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6*77c1e3ccSAndroid Build Coastguard Worker * was not distributed with this source code in the LICENSE file, you can 7*77c1e3ccSAndroid Build Coastguard Worker * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8*77c1e3ccSAndroid Build Coastguard Worker * Media Patent License 1.0 was not distributed with this source code in the 9*77c1e3ccSAndroid Build Coastguard Worker * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10*77c1e3ccSAndroid Build Coastguard Worker */ 11*77c1e3ccSAndroid Build Coastguard Worker 12*77c1e3ccSAndroid Build Coastguard Worker #ifndef AOM_AV1_ENCODER_CNN_H_ 13*77c1e3ccSAndroid Build Coastguard Worker #define AOM_AV1_ENCODER_CNN_H_ 14*77c1e3ccSAndroid Build Coastguard Worker 15*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus 16*77c1e3ccSAndroid Build Coastguard Worker extern "C" { 17*77c1e3ccSAndroid Build Coastguard Worker #endif 18*77c1e3ccSAndroid Build Coastguard Worker 19*77c1e3ccSAndroid Build Coastguard Worker #include <math.h> 20*77c1e3ccSAndroid Build Coastguard Worker #include <stdbool.h> 21*77c1e3ccSAndroid Build Coastguard Worker 22*77c1e3ccSAndroid Build Coastguard Worker #include "aom_util/aom_thread.h" 23*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h" 24*77c1e3ccSAndroid Build Coastguard Worker 25*77c1e3ccSAndroid Build Coastguard Worker struct AV1Common; 26*77c1e3ccSAndroid Build Coastguard Worker 27*77c1e3ccSAndroid Build Coastguard Worker #define CNN_MAX_HIDDEN_LAYERS 64 28*77c1e3ccSAndroid Build Coastguard Worker #define CNN_MAX_LAYERS (CNN_MAX_HIDDEN_LAYERS + 1) 29*77c1e3ccSAndroid Build Coastguard Worker #define CNN_MAX_CHANNELS 256 30*77c1e3ccSAndroid Build Coastguard Worker #define CNN_MAX_BRANCHES 4 31*77c1e3ccSAndroid Build Coastguard Worker #define CNN_MAX_THREADS 32 32*77c1e3ccSAndroid Build Coastguard Worker 33*77c1e3ccSAndroid Build Coastguard Worker #define NO_BRANCH_CONFIG \ 34*77c1e3ccSAndroid Build Coastguard Worker { 0, 0, 0 } 35*77c1e3ccSAndroid Build Coastguard Worker #define NO_BN_PARAMS \ 36*77c1e3ccSAndroid Build Coastguard Worker { NULL, NULL, NULL, NULL } 37*77c1e3ccSAndroid Build Coastguard Worker 38*77c1e3ccSAndroid Build Coastguard Worker enum { 39*77c1e3ccSAndroid Build Coastguard Worker PADDING_SAME_ZERO, // tensorflow's SAME padding with pixels outside 40*77c1e3ccSAndroid Build Coastguard Worker // the image area assumed to be 0 (default) 41*77c1e3ccSAndroid Build Coastguard Worker PADDING_SAME_REPLICATE, // tensorflow's SAME padding with pixels outside 42*77c1e3ccSAndroid Build Coastguard Worker // the image area replicated from closest edge 43*77c1e3ccSAndroid Build Coastguard Worker PADDING_VALID // tensorflow's VALID padding 44*77c1e3ccSAndroid Build Coastguard Worker } UENUM1BYTE(PADDING_TYPE); 45*77c1e3ccSAndroid Build Coastguard Worker 46*77c1e3ccSAndroid Build Coastguard Worker // enum { NONE, RELU, SOFTSIGN } UENUM1BYTE(ACTIVATION); 47*77c1e3ccSAndroid Build Coastguard Worker 48*77c1e3ccSAndroid Build Coastguard Worker // Times when input tensor may be copied to branches given in input_to_branches. 49*77c1e3ccSAndroid Build Coastguard Worker // BRANCH_NO_COPY: doesn't copy any tensor. 50*77c1e3ccSAndroid Build Coastguard Worker // BRANCH_INPUT: copies the input tensor to branches. 51*77c1e3ccSAndroid Build Coastguard Worker // BRANCH_OUTPUT: copies the convolved tensor to branches. 52*77c1e3ccSAndroid Build Coastguard Worker // BRANCH_COMBINED: copies the combined (after convolving and branch combining) 53*77c1e3ccSAndroid Build Coastguard Worker // tensor. If no combinations happen at this layer, then this option 54*77c1e3ccSAndroid Build Coastguard Worker // has the same effect as COPY_OUTPUT. 55*77c1e3ccSAndroid Build Coastguard Worker enum { 56*77c1e3ccSAndroid Build Coastguard Worker BRANCH_NO_COPY, 57*77c1e3ccSAndroid Build Coastguard Worker BRANCH_INPUT, 58*77c1e3ccSAndroid Build Coastguard Worker BRANCH_OUTPUT, 59*77c1e3ccSAndroid Build Coastguard Worker BRANCH_COMBINED 60*77c1e3ccSAndroid Build Coastguard Worker } UENUM1BYTE(BRANCH_COPY); 61*77c1e3ccSAndroid Build Coastguard Worker 62*77c1e3ccSAndroid Build Coastguard Worker // Types of combining branches with output of current layer: 63*77c1e3ccSAndroid Build Coastguard Worker // BRANCH_NOC: no branch combining 64*77c1e3ccSAndroid Build Coastguard Worker // BRANCH_ADD: Add previously stored branch tensor to output of layer 65*77c1e3ccSAndroid Build Coastguard Worker // BRANCH_CAT: Concatenate branch tensor to output of layer 66*77c1e3ccSAndroid Build Coastguard Worker enum { BRANCH_NOC, BRANCH_ADD, BRANCH_CAT } UENUM1BYTE(BRANCH_COMBINE); 67*77c1e3ccSAndroid Build Coastguard Worker 68*77c1e3ccSAndroid Build Coastguard Worker // The parameters used to scale each channel in batch 69*77c1e3ccSAndroid Build Coastguard Worker // normalization. The processing in done on a per-channel basis. 70*77c1e3ccSAndroid Build Coastguard Worker // e.g. bn_mean[c] is the mean for all pixels in channel c. This 71*77c1e3ccSAndroid Build Coastguard Worker // is always applied after activation. The output is given by 72*77c1e3ccSAndroid Build Coastguard Worker // out[c,i,j] = norm[c,i,j] * bn_gamma[c] + bn_beta[c] where 73*77c1e3ccSAndroid Build Coastguard Worker // norm[c,i,j] = (in[c,i,j] - bn_mean[c]) / bn_std[c] 74*77c1e3ccSAndroid Build Coastguard Worker // here we assume that the effect of variance_epsilon is already 75*77c1e3ccSAndroid Build Coastguard Worker // taken into account when bn_std is calculated. The pointers 76*77c1e3ccSAndroid Build Coastguard Worker // needs to be either all zero or all valid. If all zero, then 77*77c1e3ccSAndroid Build Coastguard Worker // batchnorm is disabled, else batchnorm is applied. 78*77c1e3ccSAndroid Build Coastguard Worker struct CNN_BATCHNORM_PARAMS { 79*77c1e3ccSAndroid Build Coastguard Worker const float *bn_gamma; 80*77c1e3ccSAndroid Build Coastguard Worker const float *bn_beta; 81*77c1e3ccSAndroid Build Coastguard Worker const float *bn_mean; 82*77c1e3ccSAndroid Build Coastguard Worker const float *bn_std; 83*77c1e3ccSAndroid Build Coastguard Worker }; 84*77c1e3ccSAndroid Build Coastguard Worker 85*77c1e3ccSAndroid Build Coastguard Worker struct CNN_BRANCH_CONFIG { 86*77c1e3ccSAndroid Build Coastguard Worker int input_to_branches; // If nonzero, copy the active tensor to the current 87*77c1e3ccSAndroid Build Coastguard Worker // layer and store for future use in branches 88*77c1e3ccSAndroid Build Coastguard Worker // specified in the field as a binary mask. For 89*77c1e3ccSAndroid Build Coastguard Worker // example, if input_to_branch = 0x06, it means the 90*77c1e3ccSAndroid Build Coastguard Worker // input tensor to the current branch is copied to 91*77c1e3ccSAndroid Build Coastguard Worker // branches 1 and 2 (where 0 represents the primary 92*77c1e3ccSAndroid Build Coastguard Worker // branch). One restriction is that the mask 93*77c1e3ccSAndroid Build Coastguard Worker // cannot indicate copying to the current branch. 94*77c1e3ccSAndroid Build Coastguard Worker // If greater than 0, only copies the channels up 95*77c1e3ccSAndroid Build Coastguard Worker // to the given index. 96*77c1e3ccSAndroid Build Coastguard Worker int channels_to_copy; // Within the layer, input a copy of active 97*77c1e3ccSAndroid Build Coastguard Worker // tensor to branches given in input_to_branches. 98*77c1e3ccSAndroid Build Coastguard Worker int branches_to_combine; // mask of branches to combine with output of 99*77c1e3ccSAndroid Build Coastguard Worker // current layer, if 100*77c1e3ccSAndroid Build Coastguard Worker // branch_combine_type != BRANCH_NOC 101*77c1e3ccSAndroid Build Coastguard Worker // For example, if branches_to_combine = 0x0A, 102*77c1e3ccSAndroid Build Coastguard Worker // it means that braches 1 and 3 are combined 103*77c1e3ccSAndroid Build Coastguard Worker // with the current branch. 104*77c1e3ccSAndroid Build Coastguard Worker }; 105*77c1e3ccSAndroid Build Coastguard Worker 106*77c1e3ccSAndroid Build Coastguard Worker struct CNN_LAYER_CONFIG { 107*77c1e3ccSAndroid Build Coastguard Worker int in_channels; 108*77c1e3ccSAndroid Build Coastguard Worker int filter_width; 109*77c1e3ccSAndroid Build Coastguard Worker int filter_height; 110*77c1e3ccSAndroid Build Coastguard Worker int out_channels; 111*77c1e3ccSAndroid Build Coastguard Worker int skip_width; 112*77c1e3ccSAndroid Build Coastguard Worker int skip_height; 113*77c1e3ccSAndroid Build Coastguard Worker int maxpool; // whether to use maxpool or not (only effective when 114*77c1e3ccSAndroid Build Coastguard Worker // skip width or skip_height are > 1) 115*77c1e3ccSAndroid Build Coastguard Worker const float *weights; // array of length filter_height x filter_width x 116*77c1e3ccSAndroid Build Coastguard Worker // in_channels x out_channels where the inner-most 117*77c1e3ccSAndroid Build Coastguard Worker // scan is out_channels and the outer most scan is 118*77c1e3ccSAndroid Build Coastguard Worker // filter_height. 119*77c1e3ccSAndroid Build Coastguard Worker const float *bias; // array of length out_channels 120*77c1e3ccSAndroid Build Coastguard Worker PADDING_TYPE pad; // padding type 121*77c1e3ccSAndroid Build Coastguard Worker ACTIVATION activation; // the activation function to use after convolution 122*77c1e3ccSAndroid Build Coastguard Worker int deconvolve; // whether this is a deconvolution layer. 123*77c1e3ccSAndroid Build Coastguard Worker // 0: If skip_width or skip_height are > 1, then we 124*77c1e3ccSAndroid Build Coastguard Worker // reduce resolution 125*77c1e3ccSAndroid Build Coastguard Worker // 1: If skip_width or skip_height are > 1, then we 126*77c1e3ccSAndroid Build Coastguard Worker // increase resolution 127*77c1e3ccSAndroid Build Coastguard Worker int branch; // branch index in [0, CNN_MAX_BRANCHES - 1], where 128*77c1e3ccSAndroid Build Coastguard Worker // 0 refers to the primary branch. 129*77c1e3ccSAndroid Build Coastguard Worker BRANCH_COPY branch_copy_type; 130*77c1e3ccSAndroid Build Coastguard Worker BRANCH_COMBINE branch_combine_type; 131*77c1e3ccSAndroid Build Coastguard Worker struct CNN_BRANCH_CONFIG branch_config; 132*77c1e3ccSAndroid Build Coastguard Worker struct CNN_BATCHNORM_PARAMS 133*77c1e3ccSAndroid Build Coastguard Worker bn_params; // A struct that contains the parameters 134*77c1e3ccSAndroid Build Coastguard Worker // used for batch normalization. 135*77c1e3ccSAndroid Build Coastguard Worker int output_num; // The output buffer idx to which the layer output is 136*77c1e3ccSAndroid Build Coastguard Worker // written. Set to -1 to disable writing it to the output. In 137*77c1e3ccSAndroid Build Coastguard Worker // the case that branch_combine_type is BRANCH_CAT, all 138*77c1e3ccSAndroid Build Coastguard Worker // concatenated channels will be written to output. In the 139*77c1e3ccSAndroid Build Coastguard Worker // case of BRANCH_ADD, the output will be the result of 140*77c1e3ccSAndroid Build Coastguard Worker // summation. 141*77c1e3ccSAndroid Build Coastguard Worker }; 142*77c1e3ccSAndroid Build Coastguard Worker 143*77c1e3ccSAndroid Build Coastguard Worker struct CNN_CONFIG { 144*77c1e3ccSAndroid Build Coastguard Worker int num_layers; // number of CNN layers ( = number of hidden layers + 1) 145*77c1e3ccSAndroid Build Coastguard Worker int is_residue; // whether the output activation is a residue 146*77c1e3ccSAndroid Build Coastguard Worker int ext_width, ext_height; // extension horizontally and vertically 147*77c1e3ccSAndroid Build Coastguard Worker int strict_bounds; // whether the input bounds are strict or not. 148*77c1e3ccSAndroid Build Coastguard Worker // If strict, the extension area is filled by 149*77c1e3ccSAndroid Build Coastguard Worker // replication; if not strict, image data is 150*77c1e3ccSAndroid Build Coastguard Worker // assumed available beyond the bounds. 151*77c1e3ccSAndroid Build Coastguard Worker CNN_LAYER_CONFIG layer_config[CNN_MAX_LAYERS]; 152*77c1e3ccSAndroid Build Coastguard Worker }; 153*77c1e3ccSAndroid Build Coastguard Worker 154*77c1e3ccSAndroid Build Coastguard Worker struct CNN_THREAD_DATA { 155*77c1e3ccSAndroid Build Coastguard Worker int num_workers; 156*77c1e3ccSAndroid Build Coastguard Worker AVxWorker *workers; 157*77c1e3ccSAndroid Build Coastguard Worker }; 158*77c1e3ccSAndroid Build Coastguard Worker 159*77c1e3ccSAndroid Build Coastguard Worker struct CNN_MULTI_OUT { 160*77c1e3ccSAndroid Build Coastguard Worker int num_outputs; 161*77c1e3ccSAndroid Build Coastguard Worker const int *output_channels; 162*77c1e3ccSAndroid Build Coastguard Worker const int *output_strides; 163*77c1e3ccSAndroid Build Coastguard Worker float **output_buffer; 164*77c1e3ccSAndroid Build Coastguard Worker }; 165*77c1e3ccSAndroid Build Coastguard Worker 166*77c1e3ccSAndroid Build Coastguard Worker // Function to return size of output 167*77c1e3ccSAndroid Build Coastguard Worker void av1_find_cnn_output_size(int in_width, int in_height, 168*77c1e3ccSAndroid Build Coastguard Worker const CNN_CONFIG *cnn_config, int *out_width, 169*77c1e3ccSAndroid Build Coastguard Worker int *out_height, int *out_channels); 170*77c1e3ccSAndroid Build Coastguard Worker 171*77c1e3ccSAndroid Build Coastguard Worker // Function to return output width and output height of given layer. 172*77c1e3ccSAndroid Build Coastguard Worker void av1_find_cnn_layer_output_size(int in_width, int in_height, 173*77c1e3ccSAndroid Build Coastguard Worker const CNN_LAYER_CONFIG *layer_config, 174*77c1e3ccSAndroid Build Coastguard Worker int *out_width, int *out_height); 175*77c1e3ccSAndroid Build Coastguard Worker 176*77c1e3ccSAndroid Build Coastguard Worker // Prediction functions from set of input image buffers. This function supports 177*77c1e3ccSAndroid Build Coastguard Worker // CNN with multiple outputs. 178*77c1e3ccSAndroid Build Coastguard Worker bool av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height, 179*77c1e3ccSAndroid Build Coastguard Worker int stride, const CNN_CONFIG *cnn_config, 180*77c1e3ccSAndroid Build Coastguard Worker const CNN_THREAD_DATA *thread_data, 181*77c1e3ccSAndroid Build Coastguard Worker struct CNN_MULTI_OUT *output); 182*77c1e3ccSAndroid Build Coastguard Worker bool av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height, 183*77c1e3ccSAndroid Build Coastguard Worker int stride, 184*77c1e3ccSAndroid Build Coastguard Worker const CNN_CONFIG *cnn_config, 185*77c1e3ccSAndroid Build Coastguard Worker const CNN_THREAD_DATA *thread_data, 186*77c1e3ccSAndroid Build Coastguard Worker int bit_depth, CNN_MULTI_OUT *output); 187*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus 188*77c1e3ccSAndroid Build Coastguard Worker } // extern "C" 189*77c1e3ccSAndroid Build Coastguard Worker #endif 190*77c1e3ccSAndroid Build Coastguard Worker 191*77c1e3ccSAndroid Build Coastguard Worker #endif // AOM_AV1_ENCODER_CNN_H_ 192