/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*******************************************************************************
 * Copyright (c) 2018-2023 Cadence Design Systems, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to use this Software with Cadence processor cores only and
 * not with any other processors and platforms, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

 ******************************************************************************/

#pragma once

/*
 * Unroll macros for the ROW_UNROLL == 2 && VEC_UNROLL == 2 configuration.
 *
 * When the input and the weight are both aligned to an 8-byte boundary, nnlib
 * forces ROW_UNROLL to 4. Experiments showed that nnlib's
 * xa_nn_matmul_asym8xasym8_asym8 kernel performs better in that aligned case
 * when the weight matrix is unrolled uniformly by 2 rather than 4. This header
 * therefore supplies the ROW_UNROLL=2 / VEC_UNROLL=2 variant. It mirrors the
 * ROW_UNROLL=4 / VEC_UNROLL=2 block in
 * nnlib-hifi4/xa_nnlib/algo/common/include/xa_nnlib_common_macros.h.
 *
 * Nothing is defined unless both ROW_UNROLL and VEC_UNROLL equal 2.
 *
 * The macros here only sequence calls to UNROLL_* helpers and LOAD_BIAS, none
 * of which this header defines. Those helpers must be visible wherever these
 * macros are expanded; presumably they come from the nnlib macro header named
 * above (confirm against the including kernel).
 *
 * Naming of the variants:
 *   <OP>                  - the UNROLL_ROW_* helper for indices 0 and 1.
 *   <OP>_VEC_UNROLL(row)  - the two-index helper for (row, 0) then (row, 1).
 *   <OP>_TAIL             - the two-index helper for (0, 0) then (1, 0).
 *
 * Replacement lists are kept token-for-token identical to the original
 * definitions, so any identical redefinition elsewhere remains benign.
 */
#if (ROW_UNROLL == 2 && VEC_UNROLL == 2)

/* ---- Vector side: indices 0 and 1. ---- */
#define SETUP_VEC_BATCH UNROLL_SETUP_VEC_BATCH(0) UNROLL_SETUP_VEC_BATCH(1)
#define LOAD_VEC_BATCH UNROLL_LOAD_VEC_BATCH(0) UNROLL_LOAD_VEC_BATCH(1)

/* ---- Full batch: one row-level helper call for each of rows 0 and 1. ---- */
#define SETUP_ACC_BATCH \
  UNROLL_ROW_SETUP_ACC_BATCH(0) UNROLL_ROW_SETUP_ACC_BATCH(1)
#define LOAD_MAT1 UNROLL_LOAD_ROW_MAT1(0) UNROLL_LOAD_ROW_MAT1(1)
#define KERNEL_MAT1_VEC_BATCH \
  UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(0) UNROLL_ROW_KERNEL_MAT1_VEC_BATCH(1)
#define ADD_BIAS_ACC_BATCH \
  UNROLL_ROW_ADD_BIAS_ACC(0) UNROLL_ROW_ADD_BIAS_ACC(1)
#define ADJUST_ACC_BATCH UNROLL_ROW_ADJUST_ACC(0) UNROLL_ROW_ADJUST_ACC(1)
#define STORE_ACC_BATCH UNROLL_ROW_STORE_ACC(0) UNROLL_ROW_STORE_ACC(1)

/* ---- Per row: (idx_row, 0) followed by (idx_row, 1). ---- */
#define SETUP_ACC_BATCH_VEC_UNROLL(idx_row) \
  UNROLL_SETUP_ACC_BATCH(idx_row, 0) UNROLL_SETUP_ACC_BATCH(idx_row, 1)
#define KERNEL_MAT1_VEC_BATCH_VEC_UNROLL(idx_row) \
  UNROLL_KERNEL_MAT1_VEC_BATCH(idx_row, 0) \
  UNROLL_KERNEL_MAT1_VEC_BATCH(idx_row, 1)
#define ADD_BIAS_BATCH_ACC_VEC_UNROLL(idx_row) \
  UNROLL_ADD_BIAS_ACC_BATCH(idx_row, 0) UNROLL_ADD_BIAS_ACC_BATCH(idx_row, 1)
#define ADJUST_ACC_BATCH_VEC_UNROLL(idx_row) \
  UNROLL_ADJUST_ACC_BATCH(idx_row, 0) UNROLL_ADJUST_ACC_BATCH(idx_row, 1)
#define STORE_ACC_BATCH_VEC_UNROLL(idx_row) \
  UNROLL_STORE_ACC_BATCH(idx_row, 0) UNROLL_STORE_ACC_BATCH(idx_row, 1)

/* ---- Tail: (0, 0) followed by (1, 0). ---- */
#define SETUP_ACC_BATCH_TAIL \
  UNROLL_SETUP_ACC_BATCH(0, 0) UNROLL_SETUP_ACC_BATCH(1, 0)
#define KERNEL_MAT1_VEC_BATCH_TAIL \
  UNROLL_KERNEL_MAT1_VEC_BATCH(0, 0) UNROLL_KERNEL_MAT1_VEC_BATCH(1, 0)
/*
 * LOAD_BIAS is expanded again in front of each accumulator update. Presumably
 * each expansion fetches the next bias value; confirm against LOAD_BIAS.
 */
#define ADD_BIAS_ACC_BATCH_TAIL \
  LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(0, 0) \
  LOAD_BIAS UNROLL_ADD_BIAS_ACC_BATCH(1, 0)
#define ADJUST_ACC_BATCH_TAIL \
  UNROLL_ADJUST_ACC_BATCH(0, 0) UNROLL_ADJUST_ACC_BATCH(1, 0)
#define STORE_ACC_BATCH_TAIL \
  UNROLL_STORE_ACC_BATCH(0, 0) UNROLL_STORE_ACC_BATCH(1, 0)

#endif /* ROW_UNROLL == 2 && VEC_UNROLL == 2 */