1 /* Copyright 2019 Google LLC. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef RUY_RUY_KERNEL_COMMON_H_ 17 #define RUY_RUY_KERNEL_COMMON_H_ 18 19 #include <algorithm> 20 #include <cstdint> 21 #include <type_traits> 22 23 #include "ruy/apply_multiplier.h" 24 #include "ruy/check_macros.h" 25 #include "ruy/mat.h" 26 #include "ruy/matrix.h" 27 #include "ruy/mul_params.h" 28 #include "ruy/opt_set.h" 29 #include "ruy/path.h" 30 #include "ruy/platform.h" 31 #include "ruy/profiler/instrumentation.h" 32 #include "ruy/side_pair.h" 33 #include "ruy/size_util.h" 34 #include "ruy/tune.h" 35 36 namespace ruy { 37 38 template <Path ThePath, typename LhsScalar, typename RhsScalar, 39 typename AccumScalar, typename DstScalar> 40 struct Kernel; 41 42 #define RUY_INHERIT_KERNEL(PARENT, CHILD) \ 43 template <typename LhsScalar, typename RhsScalar, typename DstScalar, \ 44 typename AccumScalar> \ 45 struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar> \ 46 : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> { \ 47 explicit Kernel(Tuning tuning) \ 48 : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>( \ 49 tuning) {} \ 50 }; 51 52 // KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. 53 // 54 // In other cases, we still define (empty) versions, so that dummy kernels 55 // can use the classes in function signatures. 56 #if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \ 57 RUY_PLATFORM_X86 58 59 #define RUY_ASM_FLAG_HAS_BIAS 0x1 60 #define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 61 #define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 62 #define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 63 #define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 64 #define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20 65 66 #define RUY_ASM_TYPE_ID_UINT8 1 67 #define RUY_ASM_TYPE_ID_INT8 2 68 #define RUY_ASM_TYPE_ID_INT16 3 69 #define RUY_ASM_TYPE_ID_INT32 4 70 71 template <typename DstScalar> 72 struct DstTypeId {}; 73 74 template <> 75 struct DstTypeId<std::uint8_t> { 76 static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; 77 }; 78 79 template <> 80 struct DstTypeId<std::int8_t> { 81 static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; 82 }; 83 84 template <> 85 struct DstTypeId<std::int16_t> { 86 static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; 87 }; 88 89 template <> 90 struct DstTypeId<std::int32_t> { 91 static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; 92 }; 93 94 template <int LhsCols, int RhsCols> 95 struct KernelParams8bit { 96 static constexpr int kMaxDstTypeSize = 4; 97 98 const std::int32_t* bias; 99 const std::int32_t* lhs_sums; 100 const std::int32_t* rhs_sums; 101 const std::int8_t* lhs_base_ptr; 102 const std::int32_t* multiplier_fixedpoint; 103 const std::int32_t* multiplier_exponent; 104 // Make it void* to support 8bit(LHS)x16bit(RHS) case. 105 const void* rhs_base_ptr; 106 void* dst_base_ptr; 107 std::int32_t lhs_zero_point; 108 std::int32_t rhs_zero_point; 109 std::int32_t dst_zero_point; 110 std::int32_t prod_zp_depth; 111 std::int32_t start_row; 112 std::int32_t start_col; 113 std::int32_t last_row; 114 std::int32_t last_col; 115 std::int32_t dst_rows; 116 std::int32_t dst_cols; 117 std::int32_t lhs_stride; 118 std::int32_t rhs_stride; 119 std::int32_t dst_stride; 120 std::int32_t depth; 121 std::int32_t clamp_min; 122 std::int32_t clamp_max; 123 std::uint8_t flags; 124 std::uint8_t dst_type_id; 125 const std::int32_t zero_data[LhsCols] = {0}; 126 std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; 127 std::int32_t multiplier_fixedpoint_buf[LhsCols]; 128 std::int32_t multiplier_exponent_buf[LhsCols]; 129 std::size_t rhs_scalar_size; 130 }; 131 132 template <typename RhsScalar, typename DstScalar, int LhsCols, int RhsCols> 133 void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, 134 const PMat<RhsScalar>& rhs, 135 const MulParams<std::int32_t, DstScalar>& mul_params, 136 int start_row, int start_col, int end_row, 137 int end_col, Mat<DstScalar>* dst, 138 KernelParams8bit<LhsCols, RhsCols>* params) { 139 using Params = KernelParams8bit<LhsCols, RhsCols>; 140 141 static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); 142 143 const int depth = lhs.layout.rows; 144 RUY_DCHECK_EQ(start_row % LhsCols, 0); 145 RUY_DCHECK_EQ(start_col % RhsCols, 0); 146 RUY_DCHECK_EQ(end_row % LhsCols, 0); 147 RUY_DCHECK_EQ(end_col % RhsCols, 0); 148 149 params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; 150 params->rhs_scalar_size = sizeof(RhsScalar); 151 params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; 152 params->flags = 0; 153 params->bias = params->zero_data; 154 if (mul_params.bias()) { 155 params->bias = mul_params.bias(); 156 params->flags |= RUY_ASM_FLAG_HAS_BIAS; 157 } 158 if (lhs.sums) { 159 params->lhs_sums = lhs.sums; 160 params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; 161 } 162 if (rhs.sums) { 163 params->rhs_sums = rhs.sums; 164 params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; 165 } 166 if (mul_params.channel_dimension() == ChannelDimension::kCol) { 167 params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; 168 } 169 params->start_row = start_row; 170 params->start_col = start_col; 171 params->last_row = end_row - LhsCols; 172 params->last_col = end_col - RhsCols; 173 params->lhs_stride = lhs.layout.stride; 174 params->rhs_stride = params->rhs_scalar_size * rhs.layout.stride; 175 params->dst_stride = sizeof(DstScalar) * dst->layout.stride; 176 params->lhs_zero_point = lhs.zero_point; 177 params->rhs_zero_point = rhs.zero_point; 178 params->dst_zero_point = dst->zero_point; 179 params->depth = depth; 180 params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; 181 params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; 182 if (mul_params.multiplier_fixedpoint_perchannel()) { 183 // Temporary release-assert to debug some crashes in an application. 184 RUY_CHECK(mul_params.multiplier_exponent_perchannel()); 185 params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL; 186 params->multiplier_fixedpoint = 187 mul_params.multiplier_fixedpoint_perchannel(); 188 params->multiplier_exponent = mul_params.multiplier_exponent_perchannel(); 189 } else { 190 params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; 191 params->multiplier_exponent = params->multiplier_exponent_buf; 192 for (int i = 0; i < LhsCols; i++) { 193 params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint(); 194 params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent(); 195 } 196 } 197 params->clamp_min = mul_params.clamp_min(); 198 params->clamp_max = mul_params.clamp_max(); 199 params->dst_rows = dst->layout.rows; 200 params->dst_cols = dst->layout.cols; 201 202 RUY_DCHECK_LT(params->last_row, params->dst_rows); 203 RUY_DCHECK_LT(params->last_col, params->dst_cols); 204 205 params->dst_type_id = DstTypeId<DstScalar>::kValue; 206 params->dst_base_ptr = 207 dst->data.get() + start_col * dst->layout.stride + start_row; 208 209 // Temporary release-asserts to debug some crashes in an application. 210 RUY_CHECK(params->multiplier_fixedpoint); 211 RUY_CHECK(params->multiplier_exponent); 212 RUY_CHECK(params->bias); 213 } 214 215 template <int LhsCols, int RhsCols> 216 struct KernelParamsFloat { 217 const float* lhs_base_ptr; 218 const float* rhs_base_ptr; 219 float* dst_base_ptr; 220 const float* bias; 221 std::int32_t start_row; 222 std::int32_t start_col; 223 std::int32_t last_row; 224 std::int32_t last_col; 225 std::int32_t dst_rows; 226 std::int32_t dst_cols; 227 std::int32_t lhs_stride; 228 std::int32_t rhs_stride; 229 std::int32_t dst_stride; 230 std::int32_t depth; 231 float clamp_min; 232 float clamp_max; 233 std::uint8_t flags; 234 const float zero_data[LhsCols] = {0}; 235 float dst_tmp_buf[LhsCols * RhsCols]; 236 }; 237 238 template <int LhsCols, int RhsCols> 239 inline void MakeKernelParamsFloat(const PMat<float>& lhs, 240 const PMat<float>& rhs, 241 const MulParams<float, float>& mul_params, 242 int start_row, int start_col, int end_row, 243 int end_col, Mat<float>* dst, 244 KernelParamsFloat<LhsCols, RhsCols>* params) { 245 const int depth = lhs.layout.rows; 246 RUY_DCHECK_EQ(start_row % LhsCols, 0); 247 RUY_DCHECK_EQ(start_col % RhsCols, 0); 248 RUY_DCHECK_EQ(end_row % LhsCols, 0); 249 RUY_DCHECK_EQ(end_col % RhsCols, 0); 250 251 params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; 252 params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; 253 params->dst_base_ptr = 254 dst->data.get() + start_col * dst->layout.stride + start_row; 255 256 std::uint8_t flags = 0; 257 params->bias = params->zero_data; 258 if (mul_params.bias()) { 259 params->bias = mul_params.bias(); 260 flags |= RUY_ASM_FLAG_HAS_BIAS; 261 } 262 if (mul_params.channel_dimension() == ChannelDimension::kCol) { 263 flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; 264 } 265 params->flags = flags; 266 params->start_row = start_row; 267 params->start_col = start_col; 268 params->last_row = end_row - LhsCols; 269 params->last_col = end_col - RhsCols; 270 params->lhs_stride = sizeof(float) * lhs.layout.stride; 271 params->rhs_stride = sizeof(float) * rhs.layout.stride; 272 params->dst_stride = sizeof(float) * dst->layout.stride; 273 params->depth = depth; 274 params->clamp_min = mul_params.clamp_min(); 275 params->clamp_max = mul_params.clamp_max(); 276 params->dst_rows = dst->layout.rows; 277 params->dst_cols = dst->layout.cols; 278 279 RUY_DCHECK_LT(params->last_row, params->dst_rows); 280 RUY_DCHECK_LT(params->last_col, params->dst_cols); 281 } 282 283 #else // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && 284 // RUY_OPT(ASM)) || RUY_PLATFORM_X86 285 286 template <int LhsCols, int RhsCols> 287 struct KernelParams8bit {}; 288 289 template <int LhsCols, int RhsCols> 290 struct KernelParamsFloat {}; 291 292 #endif // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && 293 // RUY_OPT(ASM)) || RUY_PLATFORM_X86 294 295 } // namespace ruy 296 297 #endif // RUY_RUY_KERNEL_COMMON_H_ 298