1 // Copyright 2015 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code 16 17 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 18 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 19 20 #include "simd_wrappers.h" 21 22 namespace gemmlowp { 23 24 template <typename SrcScalarType, int N> 25 struct LoadImpl<RegBlockInt32<4, N>, 26 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 27 static RegBlockInt32<4, N> Run( 28 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 29 int col) { 30 RegBlockInt32<4, N> result; 31 for (int i = 0; i < N; i++) { 32 result.buf.reg[i] = LoadInt32x4(src.data(row, col + i)); 33 } 34 return result; 35 } 36 }; 37 38 template <typename SrcScalarType, int N> 39 struct LoadImpl<RegBlockInt32<8, N>, 40 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 41 static RegBlockInt32<8, N> Run( 42 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 43 int col) { 44 RegBlockInt32<8, N> result; 45 for (int i = 0; i < N; i++) { 46 result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i)); 47 result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i)); 48 } 49 return result; 50 } 51 }; 52 53 template <typename SrcScalarType> 54 struct LoadImpl<RegBlockInt32<1, 4>, 55 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 56 static RegBlockInt32<1, 4> Run( 57 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 58 int col) { 59 RegBlockInt32<1, 4> result; 60 std::int32_t buf[4]; 61 for (int i = 0; i < 4; i++) { 62 buf[i] = src(row, col + i); 63 } 64 result.buf.reg[0] = LoadInt32x4(buf); 65 return result; 66 } 67 }; 68 69 template <typename SrcScalarType> 70 struct LoadImpl<RegBlockInt32<1, 8>, 71 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 72 static RegBlockInt32<1, 8> Run( 73 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 74 int col) { 75 RegBlockInt32<1, 8> result; 76 std::int32_t buf[8]; 77 for (int i = 0; i < 8; i++) { 78 buf[i] = src(row, col + i); 79 } 80 result.buf.reg[0] = LoadInt32x4(buf); 81 result.buf.reg[1] = LoadInt32x4(buf + 4); 82 return result; 83 } 84 }; 85 86 template <typename SrcScalarType> 87 struct LoadImpl<RegBlockInt32<4, 1>, 88 VectorMap<SrcScalarType, VectorShape::Col>> { 89 static RegBlockInt32<4, 1> Run( 90 const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) { 91 RegBlockInt32<4, 1> result; 92 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 93 return result; 94 } 95 }; 96 97 template <typename SrcScalarType> 98 struct LoadImpl<RegBlockInt32<4, 1>, 99 VectorDup<SrcScalarType, VectorShape::Col>> { 100 static RegBlockInt32<4, 1> Run( 101 const VectorDup<SrcScalarType, VectorShape::Col>& src, int) { 102 RegBlockInt32<4, 1> result; 103 result.buf.reg[0] = LoadInt32x4(src(0)); 104 return result; 105 } 106 }; 107 108 template <typename SrcScalarType, int N> 109 struct LoadForBroadcastingImpl<RegBlockInt32<4, N>, 110 VectorMap<SrcScalarType, VectorShape::Col>> { 111 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; 112 using RegisterBlockType = RegBlockInt32<4, N>; 113 using ResultBlockType = 114 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 115 SrcObjectType>::Type; 116 117 static ResultBlockType Run(const SrcObjectType& src, int pos) { 118 ResultBlockType result; 119 static_assert(ResultBlockType::kRegisterCount == 1, ""); 120 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 121 return result; 122 } 123 }; 124 125 template <typename SrcScalarType, int N> 126 struct LoadForBroadcastingImpl<RegBlockInt32<8, N>, 127 VectorMap<SrcScalarType, VectorShape::Col>> { 128 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; 129 using RegisterBlockType = RegBlockInt32<8, N>; 130 using ResultBlockType = 131 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 132 SrcObjectType>::Type; 133 134 static ResultBlockType Run(const SrcObjectType& src, int pos) { 135 ResultBlockType result; 136 static_assert(ResultBlockType::kRegisterCount == 2, ""); 137 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 138 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); 139 return result; 140 } 141 }; 142 143 template <typename SrcScalarType> 144 struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>, 145 VectorMap<SrcScalarType, VectorShape::Row>> { 146 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 147 using RegisterBlockType = RegBlockInt32<4, 1>; 148 using ResultBlockType = 149 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 150 SrcObjectType>::Type; 151 152 static ResultBlockType Run(const SrcObjectType& src, int pos) { 153 ResultBlockType result; 154 result.buf.reg[0] = src(pos); 155 return result; 156 } 157 }; 158 159 template <typename SrcScalarType, int N> 160 struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>, 161 VectorMap<SrcScalarType, VectorShape::Row>> { 162 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 163 using RegisterBlockType = RegBlockInt32<N, 4>; 164 using ResultBlockType = 165 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 166 SrcObjectType>::Type; 167 168 static ResultBlockType Run(const SrcObjectType& src, int pos) { 169 ResultBlockType result; 170 static_assert(ResultBlockType::kRegisterCount == 1, ""); 171 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 172 return result; 173 } 174 }; 175 176 template <typename SrcScalarType, int N> 177 struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>, 178 VectorMap<SrcScalarType, VectorShape::Row>> { 179 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 180 using RegisterBlockType = RegBlockInt32<N, 8>; 181 using ResultBlockType = 182 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 183 SrcObjectType>::Type; 184 185 static ResultBlockType Run(const SrcObjectType& src, int pos) { 186 ResultBlockType result; 187 static_assert(ResultBlockType::kRegisterCount == 2, ""); 188 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 189 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); 190 return result; 191 } 192 }; 193 194 // 4x1 := 4x1 + 1x1 195 template <> 196 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { 197 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 198 const RegBlockInt32<1, 1>& rhs) { 199 RegBlockInt32<4, 1> result; 200 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 201 return result; 202 } 203 }; 204 205 // 1x4 := 1x4 + 1x1 206 template <> 207 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { 208 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 209 const RegBlockInt32<1, 1>& rhs) { 210 RegBlockInt32<1, 4> result; 211 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 212 return result; 213 } 214 }; 215 216 // 4x1 := 4x1 + 4x1 217 template <> 218 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { 219 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 220 const RegBlockInt32<4, 1>& rhs) { 221 RegBlockInt32<4, 1> result; 222 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 223 return result; 224 } 225 }; 226 227 // 1x4 := 1x4 + 1x4 228 template <> 229 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { 230 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 231 const RegBlockInt32<1, 4>& rhs) { 232 RegBlockInt32<1, 4> result; 233 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 234 return result; 235 } 236 }; 237 238 // 4x4 := 4x4 + 1x4 239 template <> 240 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { 241 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 242 const RegBlockInt32<1, 4>& rhs) { 243 RegBlockInt32<4, 4> result; 244 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 245 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); 246 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); 247 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); 248 return result; 249 } 250 }; 251 252 // 4x4 := 4x4 + 4x1 253 template <> 254 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { 255 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 256 const RegBlockInt32<4, 1>& rhs) { 257 RegBlockInt32<4, 4> result; 258 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 259 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]); 260 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); 261 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]); 262 return result; 263 } 264 }; 265 266 // 8x1 := 8x1 + 1x1 267 template <> 268 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { 269 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 270 const RegBlockInt32<1, 1>& rhs) { 271 RegBlockInt32<8, 1> result; 272 const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); 273 for (int i = 0; i < 2; i++) { 274 result.buf.reg[i] = Add(lhs.buf.reg[i], p); 275 } 276 return result; 277 } 278 }; 279 280 // 8x1 := 8x1 + 8x1 281 template <> 282 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { 283 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 284 const RegBlockInt32<8, 1>& rhs) { 285 RegBlockInt32<8, 1> result; 286 for (int i = 0; i < 2; i++) { 287 result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); 288 } 289 return result; 290 } 291 }; 292 293 // 8x4 := 8x4 + 1x4 294 template <> 295 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { 296 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 297 const RegBlockInt32<1, 4>& rhs) { 298 RegBlockInt32<8, 4> result; 299 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 300 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); 301 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); 302 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); 303 result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); 304 result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); 305 result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); 306 result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); 307 return result; 308 } 309 }; 310 311 // 8x4 := 8x4 + 8x1 312 template <> 313 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { 314 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 315 const RegBlockInt32<8, 1>& rhs) { 316 RegBlockInt32<8, 4> result; 317 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 318 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); 319 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); 320 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]); 321 result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]); 322 result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]); 323 result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]); 324 result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]); 325 return result; 326 } 327 }; 328 329 // 1x8 := 1x8 + 1x8 330 template <> 331 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> { 332 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 333 const RegBlockInt32<1, 8>& rhs) { 334 RegBlockInt32<1, 8> result; 335 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 336 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); 337 return result; 338 } 339 }; 340 341 // 1x8 := 1x8 + 1x1 342 template <> 343 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { 344 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 345 const RegBlockInt32<1, 1>& rhs) { 346 RegBlockInt32<1, 8> result; 347 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 348 result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); 349 return result; 350 } 351 }; 352 353 // 4x1 := 4x1 + 1x1 354 template <> 355 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, 356 RegBlockInt32<1, 1>> { 357 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 358 const RegBlockInt32<1, 1>& rhs) { 359 RegBlockInt32<4, 1> result; 360 result.buf.reg[0] = SaturatingRoundingDoublingHighMul( 361 lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 362 return result; 363 } 364 }; 365 366 // 1x4 := 1x4 + 1x1 367 template <> 368 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, 369 RegBlockInt32<1, 1>> { 370 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 371 const RegBlockInt32<1, 1>& rhs) { 372 RegBlockInt32<1, 4> result; 373 result.buf.reg[0] = SaturatingRoundingDoublingHighMul( 374 lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 375 return result; 376 } 377 }; 378 379 // 4x1 := 4x1 + 4x1 380 template <> 381 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, 382 RegBlockInt32<4, 1>> { 383 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 384 const RegBlockInt32<4, 1>& rhs) { 385 RegBlockInt32<4, 1> result; 386 result.buf.reg[0] = 387 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); 388 return result; 389 } 390 }; 391 392 // 1x4 := 1x4 + 1x4 393 template <> 394 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, 395 RegBlockInt32<1, 4>> { 396 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 397 const RegBlockInt32<1, 4>& rhs) { 398 RegBlockInt32<1, 4> result; 399 result.buf.reg[0] = 400 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); 401 return result; 402 } 403 }; 404 405 // 4x4 := 4x4 + 1x4 406 template <> 407 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, 408 RegBlockInt32<1, 4>> { 409 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 410 const RegBlockInt32<1, 4>& rhs) { 411 RegBlockInt32<4, 4> result; 412 result.buf.reg[0] = SaturatingRoundingDoublingHighMul( 413 lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 414 result.buf.reg[1] = SaturatingRoundingDoublingHighMul( 415 lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); 416 result.buf.reg[2] = SaturatingRoundingDoublingHighMul( 417 lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); 418 result.buf.reg[3] = SaturatingRoundingDoublingHighMul( 419 lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); 420 return result; 421 } 422 }; 423 424 // 4x4 := 4x4 + 4x1 425 template <> 426 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, 427 RegBlockInt32<4, 1>> { 428 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 429 const RegBlockInt32<4, 1>& rhs) { 430 RegBlockInt32<4, 4> result; 431 result.buf.reg[0] = 432 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); 433 result.buf.reg[1] = 434 SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]); 435 result.buf.reg[2] = 436 SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); 437 result.buf.reg[3] = 438 SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]); 439 return result; 440 } 441 }; 442 443 // 8x1 := 8x1 + 1x1 444 template <> 445 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, 446 RegBlockInt32<1, 1>> { 447 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 448 const RegBlockInt32<1, 1>& rhs) { 449 RegBlockInt32<8, 1> result; 450 const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); 451 for (int i = 0; i < 2; i++) { 452 result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p); 453 } 454 return result; 455 } 456 }; 457 458 // 8x1 := 8x1 + 8x1 459 template <> 460 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, 461 RegBlockInt32<8, 1>> { 462 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 463 const RegBlockInt32<8, 1>& rhs) { 464 RegBlockInt32<8, 1> result; 465 for (int i = 0; i < 2; i++) { 466 result.buf.reg[i] = 467 SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]); 468 } 469 return result; 470 } 471 }; 472 473 // 8x4 := 8x4 + 1x4 474 template <> 475 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, 476 RegBlockInt32<1, 4>> { 477 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 478 const RegBlockInt32<1, 4>& rhs) { 479 RegBlockInt32<8, 4> result; 480 result.buf.reg[0] = SaturatingRoundingDoublingHighMul( 481 lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 482 result.buf.reg[1] = SaturatingRoundingDoublingHighMul( 483 lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); 484 result.buf.reg[2] = SaturatingRoundingDoublingHighMul( 485 lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); 486 result.buf.reg[3] = SaturatingRoundingDoublingHighMul( 487 lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); 488 result.buf.reg[4] = SaturatingRoundingDoublingHighMul( 489 lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); 490 result.buf.reg[5] = SaturatingRoundingDoublingHighMul( 491 lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); 492 result.buf.reg[6] = SaturatingRoundingDoublingHighMul( 493 lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); 494 result.buf.reg[7] = SaturatingRoundingDoublingHighMul( 495 lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); 496 return result; 497 } 498 }; 499 500 // 8x4 := 8x4 + 8x1 501 template <> 502 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, 503 RegBlockInt32<8, 1>> { 504 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 505 const RegBlockInt32<8, 1>& rhs) { 506 RegBlockInt32<8, 4> result; 507 result.buf.reg[0] = 508 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); 509 result.buf.reg[1] = 510 SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); 511 result.buf.reg[2] = 512 SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); 513 result.buf.reg[3] = 514 SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]); 515 result.buf.reg[4] = 516 SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]); 517 result.buf.reg[5] = 518 SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]); 519 result.buf.reg[6] = 520 SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]); 521 result.buf.reg[7] = 522 SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]); 523 return result; 524 } 525 }; 526 527 // 1x8 := 1x8 + 1x8 528 template <> 529 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, 530 RegBlockInt32<1, 8>> { 531 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 532 const RegBlockInt32<1, 8>& rhs) { 533 RegBlockInt32<1, 8> result; 534 result.buf.reg[0] = 535 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); 536 result.buf.reg[1] = 537 SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); 538 return result; 539 } 540 }; 541 542 // 1x8 := 1x8 + 1x1 543 template <> 544 struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, 545 RegBlockInt32<1, 1>> { 546 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 547 const RegBlockInt32<1, 1>& rhs) { 548 RegBlockInt32<1, 8> result; 549 result.buf.reg[0] = SaturatingRoundingDoublingHighMul( 550 lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 551 result.buf.reg[1] = SaturatingRoundingDoublingHighMul( 552 lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); 553 return result; 554 } 555 }; 556 557 // 4x1 := 4x1 * 1x1 558 template <> 559 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { 560 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 561 const RegBlockInt32<1, 1>& rhs) { 562 RegBlockInt32<4, 1> result; 563 result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 564 return result; 565 } 566 }; 567 568 // 4x1 := 4x1 * 4x1 569 template <> 570 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { 571 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 572 const RegBlockInt32<4, 1>& rhs) { 573 RegBlockInt32<4, 1> result; 574 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 575 return result; 576 } 577 }; 578 579 // 1x4 := 1x4 * 1x4 580 template <> 581 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { 582 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 583 const RegBlockInt32<1, 4>& rhs) { 584 RegBlockInt32<1, 4> result; 585 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 586 return result; 587 } 588 }; 589 590 // 1x4 := 1x4 * 1x1 591 template <> 592 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { 593 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 594 const RegBlockInt32<1, 1>& rhs) { 595 RegBlockInt32<1, 4> result; 596 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 597 return result; 598 } 599 }; 600 601 // 4x4 := 4x4 * 1x4 602 template <> 603 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { 604 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 605 const RegBlockInt32<1, 4>& rhs) { 606 RegBlockInt32<4, 4> result; 607 const Int32x4 p = rhs.buf.reg[0]; 608 result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p); 609 result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p); 610 result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p); 611 result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p); 612 return result; 613 } 614 }; 615 616 // 4x4 := 4x4 * 4x1 617 template <> 618 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { 619 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 620 const RegBlockInt32<4, 1>& rhs) { 621 RegBlockInt32<4, 4> result; 622 const Int32x4 p = rhs.buf.reg[0]; 623 result.buf.reg[0] = Mul(lhs.buf.reg[0], p); 624 result.buf.reg[1] = Mul(lhs.buf.reg[1], p); 625 result.buf.reg[2] = Mul(lhs.buf.reg[2], p); 626 result.buf.reg[3] = Mul(lhs.buf.reg[3], p); 627 return result; 628 } 629 }; 630 631 // 8x1 := 8x1 * 1x1 632 template <> 633 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { 634 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 635 const RegBlockInt32<1, 1>& rhs) { 636 RegBlockInt32<8, 1> result; 637 const std::int32_t p = rhs.buf.reg[0]; 638 for (int i = 0; i < 2; i++) { 639 result.buf.reg[i] = Mul(lhs.buf.reg[i], p); 640 } 641 return result; 642 } 643 }; 644 645 // 8x1 := 8x1 * 8x1 646 template <> 647 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { 648 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 649 const RegBlockInt32<8, 1>& rhs) { 650 RegBlockInt32<8, 1> result; 651 for (int i = 0; i < 2; i++) { 652 result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]); 653 } 654 return result; 655 } 656 }; 657 658 // 8x4 := 8x4 * 1x4 659 template <> 660 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { 661 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 662 const RegBlockInt32<1, 4>& rhs) { 663 RegBlockInt32<8, 4> result; 664 const Int32x4 p = rhs.buf.reg[0]; 665 for (int i = 0; i < 2; i++) { 666 result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p); 667 result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p); 668 result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p); 669 result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p); 670 } 671 return result; 672 } 673 }; 674 675 // 8x4 := 8x4 * 8x1 676 template <> 677 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { 678 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 679 const RegBlockInt32<8, 1>& rhs) { 680 RegBlockInt32<8, 4> result; 681 const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]}; 682 for (int i = 0; i < 4; i++) { 683 for (int j = 0; j < 2; j++) { 684 const int k = j + 2 * i; 685 result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]); 686 } 687 } 688 return result; 689 } 690 }; 691 692 // Rx1 += Rx1 * 1x1 693 template <int Rows> 694 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, 695 RegBlockInt32<Rows, 1>> { 696 static void Run(const RegBlockInt32<Rows, 1>& lhs, 697 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) { 698 const std::int32_t p = rhs.buf.reg[0]; 699 for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) { 700 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); 701 } 702 } 703 }; 704 705 // RxC += Rx1 * 1x1 706 template <int Rows, int Cols> 707 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, 708 RegBlockInt32<Rows, Cols>> { 709 static void Run(const RegBlockInt32<Rows, 1>& lhs, 710 const RegBlockInt32<1, 1>& rhs, 711 RegBlockInt32<Rows, Cols>* acc) { 712 const std::int32_t p = rhs.buf.reg[0]; 713 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 714 for (int i = 0; i < kRegsPerCol; i++) { 715 const Int32x4 q = Mul(lhs.buf.reg[i], p); 716 for (int j = 0; j < Cols; j++) { 717 acc->buf.reg[i + j * kRegsPerCol] = 718 Add(acc->buf.reg[i + j * kRegsPerCol], q); 719 } 720 } 721 } 722 }; 723 724 // 1xC += 1xC * 1x1 725 template <int Cols> 726 struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>, 727 RegBlockInt32<1, Cols>> { 728 static void Run(const RegBlockInt32<1, Cols>& lhs, 729 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { 730 const std::int32_t p = rhs.buf.reg[0]; 731 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { 732 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); 733 } 734 } 735 }; 736 737 // RxC += 1x1 * 1x1 738 template <int Rows, int Cols> 739 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 740 RegBlockInt32<Rows, Cols>> { 741 static void Run(const RegBlockInt32<1, 1>& lhs, 742 const RegBlockInt32<1, 1>& rhs, 743 RegBlockInt32<Rows, Cols>* acc) { 744 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); 745 for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) { 746 acc->buf.reg[i] = Add(acc->buf.reg[i], p); 747 } 748 } 749 }; 750 751 // 1x1 += 1x1 * 1x1 752 template <> 753 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 754 RegBlockInt32<1, 1>> { 755 static void Run(const RegBlockInt32<1, 1>& lhs, 756 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) { 757 MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]); 758 } 759 }; 760 761 // Rx4 += Rx1 * 1x4 762 template <int Rows> 763 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>, 764 RegBlockInt32<Rows, 4>> { 765 static void Run(const RegBlockInt32<Rows, 1>& lhs, 766 const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) { 767 const Int32x4 p = rhs.buf.reg[0]; 768 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 769 for (int i = 0; i < kRegsPerCol; i++) { 770 MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]); 771 MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]); 772 MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]); 773 MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]); 774 } 775 } 776 }; 777 778 // Rx4 += 1x4 * 1x1 779 template <int Rows> 780 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, 781 RegBlockInt32<Rows, 4>> { 782 static void Run(const RegBlockInt32<1, 4>& lhs, 783 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) { 784 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 785 Int32x4 q[4]; 786 q[0] = DupLane<0>(p); 787 q[1] = DupLane<1>(p); 788 q[2] = DupLane<2>(p); 789 q[3] = DupLane<3>(p); 790 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 791 for (int i = 0; i < kRegsPerCol; i++) { 792 for (int j = 0; j < 4; j++) { 793 acc->buf.reg[i + j * kRegsPerCol] = 794 Add(q[j], acc->buf.reg[i + j * kRegsPerCol]); 795 } 796 } 797 } 798 }; 799 800 // 1xC += 1x1 * 1x1 801 template <int Cols> 802 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 803 RegBlockInt32<1, Cols>> { 804 static void Run(const RegBlockInt32<1, 1>& lhs, 805 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { 806 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); 807 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { 808 acc->buf.reg[i] = Add(acc->buf.reg[i], p); 809 } 810 } 811 }; 812 813 // 1x4 += 1x4 * 1x1 814 template <> 815 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, 816 RegBlockInt32<1, 4>> { 817 static void Run(const RegBlockInt32<1, 4>& lhs, 818 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) { 819 const std::int32_t p = rhs.buf.reg[0]; 820 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); 821 } 822 }; 823 824 // 4xC += 4x1 * 1x1 825 template <int Cols> 826 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, 827 RegBlockInt32<4, Cols>> { 828 static void Run(const RegBlockInt32<4, 1>& lhs, 829 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) { 830 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 831 for (int i = 0; i < Cols; i++) { 832 acc->buf.reg[i] = Add(p, acc->buf.reg[i]); 833 } 834 } 835 }; 836 837 // 4x1 += 4x1 * 1x1 838 template <> 839 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, 840 RegBlockInt32<4, 1>> { 841 static void Run(const RegBlockInt32<4, 1>& lhs, 842 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) { 843 const std::int32_t p = rhs.buf.reg[0]; 844 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); 845 } 846 }; 847 848 } // namespace gemmlowp 849 850 #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 851