xref: /aosp_15_r20/external/ruy/ruy/kernel_common.h (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_KERNEL_COMMON_H_
17 #define RUY_RUY_KERNEL_COMMON_H_
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <type_traits>
22 
23 #include "ruy/apply_multiplier.h"
24 #include "ruy/check_macros.h"
25 #include "ruy/mat.h"
26 #include "ruy/matrix.h"
27 #include "ruy/mul_params.h"
28 #include "ruy/opt_set.h"
29 #include "ruy/path.h"
30 #include "ruy/platform.h"
31 #include "ruy/profiler/instrumentation.h"
32 #include "ruy/side_pair.h"
33 #include "ruy/size_util.h"
34 #include "ruy/tune.h"
35 
36 namespace ruy {
37 
// Declaration of the Kernel class template, parameterized by the code path
// and by the LHS, RHS, accumulator and destination scalar types. No definition
// of the primary template is given at this point; the RUY_INHERIT_KERNEL macro
// generates partial specializations of it.
38 template <Path ThePath, typename LhsScalar, typename RhsScalar,
39           typename AccumScalar, typename DstScalar>
40 struct Kernel;
41 
// RUY_INHERIT_KERNEL(PARENT, CHILD) lets the CHILD path reuse the kernels of
// the PARENT path. For every combination of scalar types it partially
// specializes Kernel for CHILD as a struct deriving from the matching
// Kernel<PARENT, ...>, whose only member is an explicit constructor that
// forwards its Tuning argument to the parent kernel.
#define RUY_INHERIT_KERNEL(PARENT, CHILD)                                     \
  template <typename LhsScalar, typename RhsScalar, typename AccumScalar,     \
            typename DstScalar>                                               \
  struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar>          \
      : public Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> { \
    explicit Kernel(Tuning parent_tuning)                                     \
        : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>(       \
              parent_tuning) {}                                               \
  };
51 
52 // KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code.
53 //
54 // In other cases, we still define (empty) versions, so that dummy kernels
55 // can use the classes in function signatures.
56 #if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \
57     RUY_PLATFORM_X86
58 
// Bit values OR'ed into the `flags` field of the kernel params structs.
// As set by MakeKernelParams8bit / MakeKernelParamsFloat in this file:
//   HAS_BIAS                 - mul_params.bias() is non-null.
//   HAS_LHS_SUMS             - the packed LHS has a `sums` array (8-bit only).
//   HAS_RHS_SUMS             - the packed RHS has a `sums` array (8-bit only).
//   HAS_PERCHANNEL           - per-channel multipliers are provided by
//                              mul_params (8-bit only).
//   NEEDS_LEFT_SHIFT         - always set by MakeKernelParams8bit.
//   CHANNEL_DIMENSION_IS_COL - mul_params.channel_dimension() is
//                              ChannelDimension::kCol.
// NOTE(review): the RUY_ASM_ prefix suggests these are consumed by assembly
// kernels (possibly through stringification), so they are presumably required
// to stay plain literals - confirm against the kernel sources before changing.
59 #define RUY_ASM_FLAG_HAS_BIAS 0x1
60 #define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
61 #define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
62 #define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
63 #define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
64 #define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20
65 
// Identifiers for the destination scalar type. DstTypeId maps a C++ type to
// one of these values, and MakeKernelParams8bit stores that value in
// KernelParams8bit::dst_type_id.
66 #define RUY_ASM_TYPE_ID_UINT8 1
67 #define RUY_ASM_TYPE_ID_INT8 2
68 #define RUY_ASM_TYPE_ID_INT16 3
69 #define RUY_ASM_TYPE_ID_INT32 4
70 
71 template <typename DstScalar>
72 struct DstTypeId {};
73 
74 template <>
75 struct DstTypeId<std::uint8_t> {
76   static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
77 };
78 
79 template <>
80 struct DstTypeId<std::int8_t> {
81   static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
82 };
83 
84 template <>
85 struct DstTypeId<std::int16_t> {
86   static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
87 };
88 
89 template <>
90 struct DstTypeId<std::int32_t> {
91   static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
92 };
93 
// Parameter block describing one invocation of an 8-bit kernel. It is filled
// in by MakeKernelParams8bit(). LhsCols and RhsCols are the kernel tile
// dimensions and size the embedded buffers.
// NOTE(review): presumably read by hand-written kernels at fixed field
// offsets, so the order and types of the members should not be changed
// without checking the kernel sources - confirm.
94 template <int LhsCols, int RhsCols>
95 struct KernelParams8bit {
  // Largest supported sizeof(DstScalar), in bytes; dst_tmp_buf is sized for it.
96   static constexpr int kMaxDstTypeSize = 4;
97 
  // Bias vector: mul_params.bias() when provided, otherwise zero_data.
98   const std::int32_t* bias;
  // Sums arrays of the packed LHS / RHS. Only meaningful when the matching
  // RUY_ASM_FLAG_HAS_LHS_SUMS / RUY_ASM_FLAG_HAS_RHS_SUMS bit is set in flags.
99   const std::int32_t* lhs_sums;
100   const std::int32_t* rhs_sums;
  // Packed LHS data, already advanced to start_row.
101   const std::int8_t* lhs_base_ptr;
  // Multiplier arrays: either the per-channel arrays owned by mul_params, or
  // the *_buf members below filled with the single uniform multiplier.
102   const std::int32_t* multiplier_fixedpoint;
103   const std::int32_t* multiplier_exponent;
104   // Make it void* to support 8bit(LHS)x16bit(RHS) case.
  // Packed RHS data, already advanced to start_col.
105   const void* rhs_base_ptr;
  // Destination data, already advanced to (start_row, start_col).
106   void* dst_base_ptr;
  // Zero points copied from lhs, rhs and dst respectively.
107   std::int32_t lhs_zero_point;
108   std::int32_t rhs_zero_point;
109   std::int32_t dst_zero_point;
  // lhs_zero_point * rhs_zero_point * depth.
110   std::int32_t prod_zp_depth;
  // First row / column of the destination block to compute.
111   std::int32_t start_row;
112   std::int32_t start_col;
  // end_row - LhsCols and end_col - RhsCols, i.e. where the last tile starts.
113   std::int32_t last_row;
114   std::int32_t last_col;
  // Dimensions of the whole destination matrix (dst->layout.rows / cols).
115   std::int32_t dst_rows;
116   std::int32_t dst_cols;
  // Strides. lhs_stride is the packed LHS layout stride as-is; rhs_stride and
  // dst_stride are the layout strides multiplied by the element size.
117   std::int32_t lhs_stride;
118   std::int32_t rhs_stride;
119   std::int32_t dst_stride;
  // Taken from lhs.layout.rows.
120   std::int32_t depth;
  // Clamp bounds copied from mul_params.
121   std::int32_t clamp_min;
122   std::int32_t clamp_max;
  // Bitwise OR of RUY_ASM_FLAG_* values.
123   std::uint8_t flags;
  // One of the RUY_ASM_TYPE_ID_* values, obtained from DstTypeId<DstScalar>.
124   std::uint8_t dst_type_id;
  // All-zero array that `bias` points to when no bias is provided.
125   const std::int32_t zero_data[LhsCols] = {0};
  // Room for one LhsCols x RhsCols tile of the widest destination type. Not
  // written by MakeKernelParams8bit; presumably scratch space for the kernels
  // - TODO confirm.
126   std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
  // Backing storage for multiplier_fixedpoint / multiplier_exponent when the
  // multiplier is not per-channel.
127   std::int32_t multiplier_fixedpoint_buf[LhsCols];
128   std::int32_t multiplier_exponent_buf[LhsCols];
  // sizeof(RhsScalar), in bytes.
129   std::size_t rhs_scalar_size;
130 };
131 
132 template <typename RhsScalar, typename DstScalar, int LhsCols, int RhsCols>
133 void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
134                           const PMat<RhsScalar>& rhs,
135                           const MulParams<std::int32_t, DstScalar>& mul_params,
136                           int start_row, int start_col, int end_row,
137                           int end_col, Mat<DstScalar>* dst,
138                           KernelParams8bit<LhsCols, RhsCols>* params) {
139   using Params = KernelParams8bit<LhsCols, RhsCols>;
140 
141   static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
142 
143   const int depth = lhs.layout.rows;
144   RUY_DCHECK_EQ(start_row % LhsCols, 0);
145   RUY_DCHECK_EQ(start_col % RhsCols, 0);
146   RUY_DCHECK_EQ(end_row % LhsCols, 0);
147   RUY_DCHECK_EQ(end_col % RhsCols, 0);
148 
149   params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
150   params->rhs_scalar_size = sizeof(RhsScalar);
151   params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
152   params->flags = 0;
153   params->bias = params->zero_data;
154   if (mul_params.bias()) {
155     params->bias = mul_params.bias();
156     params->flags |= RUY_ASM_FLAG_HAS_BIAS;
157   }
158   if (lhs.sums) {
159     params->lhs_sums = lhs.sums;
160     params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
161   }
162   if (rhs.sums) {
163     params->rhs_sums = rhs.sums;
164     params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
165   }
166   if (mul_params.channel_dimension() == ChannelDimension::kCol) {
167     params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
168   }
169   params->start_row = start_row;
170   params->start_col = start_col;
171   params->last_row = end_row - LhsCols;
172   params->last_col = end_col - RhsCols;
173   params->lhs_stride = lhs.layout.stride;
174   params->rhs_stride = params->rhs_scalar_size * rhs.layout.stride;
175   params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
176   params->lhs_zero_point = lhs.zero_point;
177   params->rhs_zero_point = rhs.zero_point;
178   params->dst_zero_point = dst->zero_point;
179   params->depth = depth;
180   params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
181   params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
182   if (mul_params.multiplier_fixedpoint_perchannel()) {
183     // Temporary release-assert to debug some crashes in an application.
184     RUY_CHECK(mul_params.multiplier_exponent_perchannel());
185     params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
186     params->multiplier_fixedpoint =
187         mul_params.multiplier_fixedpoint_perchannel();
188     params->multiplier_exponent = mul_params.multiplier_exponent_perchannel();
189   } else {
190     params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
191     params->multiplier_exponent = params->multiplier_exponent_buf;
192     for (int i = 0; i < LhsCols; i++) {
193       params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint();
194       params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent();
195     }
196   }
197   params->clamp_min = mul_params.clamp_min();
198   params->clamp_max = mul_params.clamp_max();
199   params->dst_rows = dst->layout.rows;
200   params->dst_cols = dst->layout.cols;
201 
202   RUY_DCHECK_LT(params->last_row, params->dst_rows);
203   RUY_DCHECK_LT(params->last_col, params->dst_cols);
204 
205   params->dst_type_id = DstTypeId<DstScalar>::kValue;
206   params->dst_base_ptr =
207       dst->data.get() + start_col * dst->layout.stride + start_row;
208 
209   // Temporary release-asserts to debug some crashes in an application.
210   RUY_CHECK(params->multiplier_fixedpoint);
211   RUY_CHECK(params->multiplier_exponent);
212   RUY_CHECK(params->bias);
213 }
214 
// Parameter block describing one invocation of a float kernel. It is filled
// in by MakeKernelParamsFloat(). LhsCols and RhsCols are the kernel tile
// dimensions and size the embedded buffers.
// NOTE(review): presumably read by hand-written kernels at fixed field
// offsets, so the order and types of the members should not be changed
// without checking the kernel sources - confirm.
215 template <int LhsCols, int RhsCols>
216 struct KernelParamsFloat {
  // Packed LHS / RHS data, already advanced to start_row / start_col.
217   const float* lhs_base_ptr;
218   const float* rhs_base_ptr;
  // Destination data, already advanced to (start_row, start_col).
219   float* dst_base_ptr;
  // Bias vector: mul_params.bias() when provided, otherwise zero_data.
220   const float* bias;
  // First row / column of the destination block to compute.
221   std::int32_t start_row;
222   std::int32_t start_col;
  // end_row - LhsCols and end_col - RhsCols, i.e. where the last tile starts.
223   std::int32_t last_row;
224   std::int32_t last_col;
  // Dimensions of the whole destination matrix (dst->layout.rows / cols).
225   std::int32_t dst_rows;
226   std::int32_t dst_cols;
  // Layout strides multiplied by sizeof(float).
227   std::int32_t lhs_stride;
228   std::int32_t rhs_stride;
229   std::int32_t dst_stride;
  // Taken from lhs.layout.rows.
230   std::int32_t depth;
  // Clamp bounds copied from mul_params.
231   float clamp_min;
232   float clamp_max;
  // Bitwise OR of RUY_ASM_FLAG_* values; MakeKernelParamsFloat only ever sets
  // HAS_BIAS and CHANNEL_DIMENSION_IS_COL.
233   std::uint8_t flags;
  // All-zero array that `bias` points to when no bias is provided.
234   const float zero_data[LhsCols] = {0};
  // Room for one LhsCols x RhsCols tile. Not written by MakeKernelParamsFloat;
  // presumably scratch space for the kernels - TODO confirm.
235   float dst_tmp_buf[LhsCols * RhsCols];
236 };
237 
238 template <int LhsCols, int RhsCols>
239 inline void MakeKernelParamsFloat(const PMat<float>& lhs,
240                                   const PMat<float>& rhs,
241                                   const MulParams<float, float>& mul_params,
242                                   int start_row, int start_col, int end_row,
243                                   int end_col, Mat<float>* dst,
244                                   KernelParamsFloat<LhsCols, RhsCols>* params) {
245   const int depth = lhs.layout.rows;
246   RUY_DCHECK_EQ(start_row % LhsCols, 0);
247   RUY_DCHECK_EQ(start_col % RhsCols, 0);
248   RUY_DCHECK_EQ(end_row % LhsCols, 0);
249   RUY_DCHECK_EQ(end_col % RhsCols, 0);
250 
251   params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
252   params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
253   params->dst_base_ptr =
254       dst->data.get() + start_col * dst->layout.stride + start_row;
255 
256   std::uint8_t flags = 0;
257   params->bias = params->zero_data;
258   if (mul_params.bias()) {
259     params->bias = mul_params.bias();
260     flags |= RUY_ASM_FLAG_HAS_BIAS;
261   }
262   if (mul_params.channel_dimension() == ChannelDimension::kCol) {
263     flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
264   }
265   params->flags = flags;
266   params->start_row = start_row;
267   params->start_col = start_col;
268   params->last_row = end_row - LhsCols;
269   params->last_col = end_col - RhsCols;
270   params->lhs_stride = sizeof(float) * lhs.layout.stride;
271   params->rhs_stride = sizeof(float) * rhs.layout.stride;
272   params->dst_stride = sizeof(float) * dst->layout.stride;
273   params->depth = depth;
274   params->clamp_min = mul_params.clamp_min();
275   params->clamp_max = mul_params.clamp_max();
276   params->dst_rows = dst->layout.rows;
277   params->dst_cols = dst->layout.cols;
278 
279   RUY_DCHECK_LT(params->last_row, params->dst_rows);
280   RUY_DCHECK_LT(params->last_col, params->dst_cols);
281 }
282 
283 #else  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
284        // RUY_OPT(ASM)) || RUY_PLATFORM_X86
285 
// Empty stand-ins used when the preprocessor condition guarding the real
// definitions is false. They exist only so that code naming these types in
// function signatures still compiles.
286 template <int LhsCols, int RhsCols>
287 struct KernelParams8bit {};
288 
289 template <int LhsCols, int RhsCols>
290 struct KernelParamsFloat {};
291 
292 #endif  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
293         //  RUY_OPT(ASM)) || RUY_PLATFORM_X86
294 
295 }  // namespace ruy
296 
297 #endif  // RUY_RUY_KERNEL_COMMON_H_
298