/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Kernel specializations for the ARM NEON paths (Path::kNeon and
// Path::kNeonDotprod). Everything declared inside namespace ruy below is only
// compiled when RUY_PLATFORM_NEON is set and the ASM optimization is enabled.

#ifndef RUY_RUY_KERNEL_ARM_H_
#define RUY_RUY_KERNEL_ARM_H_

#include <cstddef>
#include <cstdint>

#include "ruy/asm_helpers.h"
#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/side_pair.h"
#include "ruy/size_util.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_NEON && RUY_OPT(ASM)

// NOTE(review): presumably these let kNeon fall back to the kStandardCpp
// kernels, and kNeonDotprod fall back to the kNeon kernels, for any type
// combination that is not explicitly specialized below -- confirm against the
// RUY_INHERIT_KERNEL definition.
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)

// 8-bit kernel entry points. Only declarations appear in this header; the
// definitions are out of line.
//
// The generic and 1-column NEON kernels take KernelParams8bit<4, 4> on ARM64
// and KernelParams8bit<4, 2> on ARM32.
#if RUY_PLATFORM_NEON_64
void Kernel8bitNeon(const KernelParams8bit<4, 4>& params);
void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params);
#elif RUY_PLATFORM_NEON_32
void Kernel8bitNeon(const KernelParams8bit<4, 2>& params);
void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params);
#endif
// Tuning- and dotprod-specific variants. They are declared on every NEON
// platform, but within this header they are only called from the
// RUY_PLATFORM_NEON_64 kernel specializations.
void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params);
void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params);
void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params);
void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params);
void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params);

54 #if RUY_PLATFORM_NEON_64 55 template <typename DstScalar> 56 struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> { 57 static constexpr Path kPath = Path::kNeon; 58 using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; 59 using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; 60 Tuning tuning = Tuning::kAuto; 61 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 62 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, 63 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, 64 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { 65 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; 66 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, 67 end_col, dst, ¶ms); 68 if (dst->layout.cols == 1 && 69 mul_params.channel_dimension() == ChannelDimension::kRow) { 70 Kernel8bitNeon1Col(params); 71 return; 72 } 73 if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 74 Kernel8bitNeonA55ish(params); 75 } else { 76 Kernel8bitNeon(params); 77 } 78 } 79 }; 80 #endif 81 82 #if RUY_PLATFORM_NEON_32 83 template <typename DstScalar> 84 struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> { 85 static constexpr Path kPath = Path::kNeon; 86 using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>; 87 using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>; 88 Tuning tuning = Tuning::kAuto; 89 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 90 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, 91 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, 92 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { 93 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; 94 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, 95 end_col, dst, ¶ms); 96 if (dst->layout.cols == 1 && 97 mul_params.channel_dimension() == ChannelDimension::kRow) { 98 
Kernel8bitNeon1Col(params); 99 return; 100 } 101 Kernel8bitNeon(params); 102 } 103 }; 104 #endif 105 106 #if RUY_PLATFORM_NEON_64 107 template <typename DstScalar> 108 struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t, 109 DstScalar> { 110 static constexpr Path kPath = Path::kNeonDotprod; 111 Tuning tuning = Tuning::kAuto; 112 using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; 113 using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; 114 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 115 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, 116 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, 117 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { 118 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; 119 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, 120 end_col, dst, ¶ms); 121 if (dst->layout.cols == 1 && 122 mul_params.channel_dimension() == ChannelDimension::kRow) { 123 Kernel8bitNeonDotprod1Col(params); 124 } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 125 Kernel8bitNeonDotprodA55ish(params); 126 } else if (tuning == Tuning::kX1) { 127 Kernel8bitNeonDotprodX1(params); 128 } else { 129 Kernel8bitNeonDotprod(params); 130 } 131 } 132 }; 133 #endif 134 135 void KernelFloatNeon(const KernelParamsFloat<8, 8>& params); 136 void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params); 137 void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params); 138 void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params); 139 void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params); 140 141 #if RUY_PLATFORM_NEON_64 142 // A Float kernel for ARM64 Neon. 
143 template <> 144 struct Kernel<Path::kNeon, float, float, float, float> { 145 static constexpr Path kPath = Path::kNeon; 146 Tuning tuning = Tuning::kAuto; 147 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 148 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 149 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 150 void Run(const PMat<float>& lhs, const PMat<float>& rhs, 151 const MulParams<float, float>& mul_params, int start_row, 152 int start_col, int end_row, int end_col, Mat<float>* dst) const { 153 KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; 154 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, 155 end_col, dst, ¶ms); 156 if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 157 KernelFloatNeonA55ish(params); 158 } else if (tuning == Tuning::kX1) { 159 KernelFloatNeonX1(params); 160 } else { 161 KernelFloatNeon(params); 162 } 163 } 164 }; 165 #endif 166 167 #if RUY_PLATFORM_NEON_32 168 // A Float kernel for ARM32 Neon. 169 template <> 170 struct Kernel<Path::kNeon, float, float, float, float> { 171 static constexpr Path kPath = Path::kNeon; 172 Tuning tuning = Tuning::kAuto; 173 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 174 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>; 175 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 176 void Run(const PMat<float>& lhs, const PMat<float>& rhs, 177 const MulParams<float, float>& mul_params, int start_row, 178 int start_col, int end_row, int end_col, Mat<float>* dst) const { 179 KernelParamsFloat<8, 4> params; 180 181 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, 182 end_col, dst, ¶ms); 183 184 KernelFloat32Neon(params); 185 } 186 }; 187 #endif 188 189 // While the dotprod NEON extension does not concern floating-point arithmetic, 190 // its presence allows us to distinguish, in the in-order tuning case, between 191 // A53 and A55r1. TODO: should this be folded into tuning? 
192 template <> 193 struct Kernel<Path::kNeonDotprod, float, float, float, float> { 194 static constexpr Path kPath = Path::kNeonDotprod; 195 Tuning tuning = Tuning::kAuto; 196 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 197 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; 198 using Base = Kernel<Path::kNeon, float, float, float, float>; 199 explicit Kernel(Tuning tuning_) : tuning(tuning_) {} 200 void Run(const PMat<float>& lhs, const PMat<float>& rhs, 201 const MulParams<float, float>& mul_params, int start_row, 202 int start_col, int end_row, int end_col, Mat<float>* dst) const { 203 KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; 204 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, 205 end_col, dst, ¶ms); 206 if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 207 KernelFloatNeonDotprodA55ish(params); 208 } else if (tuning == Tuning::kX1) { 209 KernelFloatNeonX1(params); 210 } else { 211 KernelFloatNeon(params); 212 } 213 } 214 }; 215 216 #endif // RUY_PLATFORM_NEON && RUY_OPT(ASM) 217 218 } // namespace ruy 219 220 #endif // RUY_RUY_KERNEL_ARM_H_ 221