1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <inttypes.h> 12 #include <stddef.h> 13 #include <xa_type_def.h> 14 /* For NNLIB APIs */ 15 #include "xa_nnlib_kernels_api.h" 16 17 /* Potential NNLIB function/APIs */ 18 extern "C" WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32( 19 FLOAT32* __restrict__ p_out, 20 const WORD32* const p_out_shape, 21 const FLOAT32* __restrict__ p_inp1, 22 const WORD32* const p_inp1_shape, 23 const FLOAT32* __restrict__ p_inp2, 24 const WORD32* const p_inp2_shape); 25 26 extern "C" WORD32 xa_nn_elm_div_broadcast_4D_f32xf32_f32( 27 FLOAT32* __restrict__ p_out, 28 const WORD32* const p_out_shape, 29 const FLOAT32* __restrict__ p_inp1, 30 const WORD32* const p_inp1_shape, 31 const FLOAT32* __restrict__ p_inp2, 32 const WORD32* const p_inp2_shape); 33 34 extern "C" WORD32 xa_nn_elm_div_mode_f32xf32_f32( 35 FLOAT32* __restrict__ p_out, 36 const FLOAT32* __restrict__ p_inp1, 37 const FLOAT32* __restrict__ p_inp2, 38 WORD32 num_elm, 39 WORD32 mode); 40 41 extern "C" WORD32 xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32( 42 FLOAT32* __restrict__ p_out, 43 const WORD32* const p_out_shape, 44 const FLOAT32* __restrict__ p_inp1, 45 const WORD32* const p_inp1_shape, 46 const FLOAT32* __restrict__ p_inp2, 47 const WORD32* const p_inp2_shape, 48 WORD32 mode); 49 50 extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32( 51 FLOAT32* __restrict__ p_out, 52 const WORD32* const p_out_shape, 53 const FLOAT32* __restrict__ p_inp1, 54 const WORD32* const p_inp1_shape, 55 const FLOAT32* __restrict__ p_inp2, 56 const WORD32* const p_inp2_shape); 57 58 extern "C" WORD32 xa_nn_elm_where_f32xf32_f32( 59 FLOAT32* __restrict__ p_out, 60 const FLOAT32* __restrict__ p_inp1, 61 const FLOAT32* __restrict__ p_inp2, 62 const unsigned char* __restrict__ p_condition, 63 WORD32 num_elm); 64 65 extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32( 66 FLOAT32* __restrict__ p_out, 67 const WORD32* const p_out_shape, 68 const FLOAT32* __restrict__ p_inp1, 69 const WORD32* const p_inp1_shape, 70 const FLOAT32* __restrict__ p_inp2, 71 const WORD32* const p_inp2_shape, 72 const unsigned char* __restrict__ p_condition, 73 const WORD32* const p_condition_shape); 74 75 extern "C" WORD32 xa_nn_reduce_mean_4D_f32_f32( 76 FLOAT32* __restrict__ p_out, 77 const WORD32* const p_out_shape, 78 const FLOAT32* __restrict__ p_inp, 79 const WORD32* const p_inp_shape, 80 const WORD32* __restrict__ p_axis, 81 WORD32 num_out_dims, 82 WORD32 num_inp_dims, 83 WORD32 num_axis_dims, 84 void* __restrict__ p_scratch_in); 85 86 namespace cadence { 87 namespace impl { 88 namespace HiFi { 89 namespace kernels { 90 91 void memcpy(void* dst, const void* src, size_t num_bytes); 92 93 WORD32 matmul_asym8uxasym8u_asym8u( 94 UWORD8* __restrict__ p_out, // output uint8 matrix 95 const UWORD8* __restrict__ p_mat1, // weight uint8 matrix 96 const UWORD8* __restrict__ p_vec1, // input uint8 matrix 97 const WORD32* __restrict__ p_bias, // bias int32 vec 98 WORD32 rows, // rows of p_mat1 99 WORD32 cols1, // columns of p_mat1 100 WORD32 row_stride1, // row stride of p_mat1 101 WORD32 vec_count, // rows of p_mat2 102 WORD32 vec_offset, // vec_offset of p_mat2. 103 WORD32 out_offset, // out_offset, i.e., offset of next output element 104 WORD32 out_stride, // out_stride, i.e., stride to go to next output row 105 WORD32 mat1_zero_bias, // zero_point of p_mat1 106 WORD32 vec1_zero_bias, // zero_point of p_vec1 107 const WORD32* __restrict__ out_multiplier, 108 const WORD32* __restrict__ out_shift, 109 WORD32 out_zero_bias, 110 bool per_channel_quantized = false); // per-channel quantized weight 111 112 WORD32 xa_nn_matmul_asym8uxasym8u_asym8u( 113 UWORD8* __restrict__ p_out, 114 const UWORD8* __restrict__ p_mat1, 115 const UWORD8* __restrict__ p_mat2, 116 const WORD32* __restrict__ p_bias, 117 WORD32 rows, 118 WORD32 cols, 119 WORD32 row_stride, 120 WORD32 vec_count, 121 WORD32 vec_offset, 122 WORD32 out_offset, 123 WORD32 out_stride, 124 WORD32 mat1_zero_bias, 125 WORD32 vec1_zero_bias, 126 WORD32 out_multiplier, 127 WORD32 out_shift, 128 WORD32 out_zero_bias); 129 130 template <typename T> 131 T quantize(const float x, float scale, int32_t zero_point); 132 133 template <typename T> 134 float dequantize(const T x, float scale, int32_t zero_point); 135 136 template <typename T> 137 void quantize( 138 T* __restrict__ y, 139 const float* __restrict__ x, 140 float scale, 141 int32_t zero_point, 142 size_t size); 143 144 // Deuantize an int8_t/uint8_t/int16_t array to an fp32 array 145 template <typename T> 146 void dequantize( 147 float* __restrict__ y, 148 const T* __restrict__ x, 149 float scale, 150 int32_t zero_point, 151 size_t size); 152 153 }; // namespace kernels 154 }; // namespace HiFi 155 }; // namespace impl 156 }; // namespace cadence 157