xref: /aosp_15_r20/external/ruy/ruy/kernel_arm.h (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_KERNEL_ARM_H_
17 #define RUY_RUY_KERNEL_ARM_H_
18 
19 #include <cstddef>
20 #include <cstdint>
21 
22 #include "ruy/asm_helpers.h"
23 #include "ruy/kernel_common.h"
24 #include "ruy/mat.h"
25 #include "ruy/mul_params.h"
26 #include "ruy/opt_set.h"
27 #include "ruy/path.h"
28 #include "ruy/platform.h"
29 #include "ruy/profiler/instrumentation.h"
30 #include "ruy/side_pair.h"
31 #include "ruy/size_util.h"
32 #include "ruy/tune.h"
33 
34 namespace ruy {
35 
36 #if RUY_PLATFORM_NEON && RUY_OPT(ASM)
37 
38 RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
39 RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
40 
41 #if RUY_PLATFORM_NEON_64
42 void Kernel8bitNeon(const KernelParams8bit<4, 4>& params);
43 void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params);
44 #elif RUY_PLATFORM_NEON_32
45 void Kernel8bitNeon(const KernelParams8bit<4, 2>& params);
46 void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params);
47 #endif
48 void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params);
49 void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params);
50 void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params);
51 void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params);
52 void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params);
53 
54 #if RUY_PLATFORM_NEON_64
55 template <typename DstScalar>
56 struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
57   static constexpr Path kPath = Path::kNeon;
58   using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
59   using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
60   Tuning tuning = Tuning::kAuto;
61   explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
62   void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
63            const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
64            int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
65     KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
66     MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
67                          end_col, dst, &params);
68     if (dst->layout.cols == 1 &&
69         mul_params.channel_dimension() == ChannelDimension::kRow) {
70       Kernel8bitNeon1Col(params);
71       return;
72     }
73     if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
74       Kernel8bitNeonA55ish(params);
75     } else {
76       Kernel8bitNeon(params);
77     }
78   }
79 };
80 #endif
81 
82 #if RUY_PLATFORM_NEON_32
83 template <typename DstScalar>
84 struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
85   static constexpr Path kPath = Path::kNeon;
86   using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
87   using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>;
88   Tuning tuning = Tuning::kAuto;
89   explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
90   void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
91            const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
92            int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
93     KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
94     MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
95                          end_col, dst, &params);
96     if (dst->layout.cols == 1 &&
97         mul_params.channel_dimension() == ChannelDimension::kRow) {
98       Kernel8bitNeon1Col(params);
99       return;
100     }
101     Kernel8bitNeon(params);
102   }
103 };
104 #endif
105 
106 #if RUY_PLATFORM_NEON_64
107 template <typename DstScalar>
108 struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t,
109               DstScalar> {
110   static constexpr Path kPath = Path::kNeonDotprod;
111   Tuning tuning = Tuning::kAuto;
112   using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
113   using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
114   explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
115   void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
116            const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
117            int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
118     KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
119     MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
120                          end_col, dst, &params);
121     if (dst->layout.cols == 1 &&
122         mul_params.channel_dimension() == ChannelDimension::kRow) {
123       Kernel8bitNeonDotprod1Col(params);
124     } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
125       Kernel8bitNeonDotprodA55ish(params);
126     } else if (tuning == Tuning::kX1) {
127       Kernel8bitNeonDotprodX1(params);
128     } else {
129       Kernel8bitNeonDotprod(params);
130     }
131   }
132 };
133 #endif
134 
135 void KernelFloatNeon(const KernelParamsFloat<8, 8>& params);
136 void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params);
137 void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params);
138 void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params);
139 void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params);
140 
141 #if RUY_PLATFORM_NEON_64
142 // A Float kernel for ARM64 Neon.
143 template <>
144 struct Kernel<Path::kNeon, float, float, float, float> {
145   static constexpr Path kPath = Path::kNeon;
146   Tuning tuning = Tuning::kAuto;
147   using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
148   using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
149   explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
150   void Run(const PMat<float>& lhs, const PMat<float>& rhs,
151            const MulParams<float, float>& mul_params, int start_row,
152            int start_col, int end_row, int end_col, Mat<float>* dst) const {
153     KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
154     MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
155                           end_col, dst, &params);
156     if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
157       KernelFloatNeonA55ish(params);
158     } else if (tuning == Tuning::kX1) {
159       KernelFloatNeonX1(params);
160     } else {
161       KernelFloatNeon(params);
162     }
163   }
164 };
165 #endif
166 
167 #if RUY_PLATFORM_NEON_32
168 // A Float kernel for ARM32 Neon.
169 template <>
170 struct Kernel<Path::kNeon, float, float, float, float> {
171   static constexpr Path kPath = Path::kNeon;
172   Tuning tuning = Tuning::kAuto;
173   using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
174   using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
175   explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
176   void Run(const PMat<float>& lhs, const PMat<float>& rhs,
177            const MulParams<float, float>& mul_params, int start_row,
178            int start_col, int end_row, int end_col, Mat<float>* dst) const {
179     KernelParamsFloat<8, 4> params;
180 
181     MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
182                           end_col, dst, &params);
183 
184     KernelFloat32Neon(params);
185   }
186 };
187 #endif
188 
189 // While the dotprod NEON extension does not concern floating-point arithmetic,
190 // its presence allows us to distinguish, in the in-order tuning case, between
191 // A53 and A55r1. TODO: should this be folded into tuning?
192 template <>
193 struct Kernel<Path::kNeonDotprod, float, float, float, float> {
194   static constexpr Path kPath = Path::kNeonDotprod;
195   Tuning tuning = Tuning::kAuto;
196   using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
197   using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
198   using Base = Kernel<Path::kNeon, float, float, float, float>;
199   explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
200   void Run(const PMat<float>& lhs, const PMat<float>& rhs,
201            const MulParams<float, float>& mul_params, int start_row,
202            int start_col, int end_row, int end_col, Mat<float>* dst) const {
203     KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
204     MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
205                           end_col, dst, &params);
206     if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
207       KernelFloatNeonDotprodA55ish(params);
208     } else if (tuning == Tuning::kX1) {
209       KernelFloatNeonX1(params);
210     } else {
211       KernelFloatNeon(params);
212     }
213   }
214 };
215 
216 #endif  // RUY_PLATFORM_NEON && RUY_OPT(ASM)
217 
218 }  // namespace ruy
219 
220 #endif  // RUY_RUY_KERNEL_ARM_H_
221