#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>

#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#endif

// Sleef offers vectorized versions of some transcendentals
// such as sin, cos, tan, etc.
// However, for now we opt for the STL versions, since we are not
// building with Sleef for mobile yet.

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Right now this contains only the aarch64 implementation.
// aarch32 is not currently supported, for the following two reasons:
// 1. Due to differences between the aarch32 and aarch64 ISAs,
//    intrinsics that work for aarch64 don't work for aarch32.
// 2. Android NDK r21 has problems compiling aarch32;
//    Clang segfaults.
//    https://github.com/android/ndk/issues/1248
//    https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will add aarch32 support with inline asm.
#if defined(__aarch64__)

#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif

#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif

template<int index, bool mask_val>
struct BlendRegs {
  static float32x4_t impl(
    const float32x4_t& a, const float32x4_t& b, float32x4_t& res);
};

template<int index>
struct BlendRegs<index, true>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
  }
};

template<int index>
struct BlendRegs<index, false>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
  }
};

template <> class Vectorized<float> {
private:
  float32x4x2_t values;
public:
  using value_type = float;
  using size_type = int;
  static constexpr size_type size() {
    return 8;
  }
  Vectorized() {}
  Vectorized(float32x4x2_t v) : values(v) {}
  Vectorized(float val) : values{vdupq_n_f32(val), vdupq_n_f32(val)} {}
  Vectorized(float val0, float val1, float val2, float val3,
             float val4, float val5, float val6, float val7) :
      values{val0, val1, val2, val3, val4, val5, val6, val7} {}
  Vectorized(float32x4_t val0, float32x4_t val1) : values{val0, val1} {}
  operator float32x4x2_t() const {
    return values;
  }
  // Bit i of `mask` selects element i from b (if set) or from a (if clear);
  // bits 0-3 control values.val[0] and bits 4-7 control values.val[1].
  template <int64_t mask>
  static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
    Vectorized<float> vec;
    // 0.
    vec.values.val[0] =
      BlendRegs<0, (mask & 0x01)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<1, (mask & 0x02)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<2, (mask & 0x04)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<3, (mask & 0x08)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    // 1.
    vec.values.val[1] =
      BlendRegs<0, (mask & 0x10)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<1, (mask & 0x20)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<2, (mask & 0x40)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<3, (mask & 0x80)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    return vec;
  }
  static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
      const Vectorized<float>& mask) {
    // TODO
    // NB: This requires that each 32-bit lane of the mask be either
    // all zeros or all ones.
    // We could perhaps add some kind of an assert,
    // but that would affect performance.
    Vectorized<float> vec(mask.values);
    vec.values.val[0] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[0]),
        b.values.val[0],
        a.values.val[0]);
    vec.values.val[1] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[1]),
        b.values.val[1],
        a.values.val[1]);
    return vec;
  }
  template<typename step_t>
  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
    const Vectorized<float> base_vec(base);
    const Vectorized<float> step_vec(step);
    const Vectorized<float> step_sizes(0, 1, 2, 3, 4, 5, 6, 7);
    return fmadd(step_sizes, step_vec, base_vec);
  }
  // Returns a vector whose first `count` elements come from b and whose
  // remaining elements come from a.
  static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
      int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 2:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 3:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 4:
        return Vectorized<float>(b.values.val[0], a.values.val[1]);
      case 5:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0x0, 0x0, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
      case 6:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
      case 7:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
    }
    return b;
  }
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
    if (count == size()) {
      return vld1q_f32_x2(reinterpret_cast<const float*>(ptr));
    }
    else if (count == (size() >> 1)) {
      Vectorized<float> res;
      res.values.val[0] = vld1q_f32(reinterpret_cast<const float*>(ptr));
      res.values.val[1] = vdupq_n_f32(0.f);
      return res;
    }
    else {
      __at_align__ float tmp_values[size()];
      for (const auto i : c10::irange(size())) {
        tmp_values[i] = 0.0;
      }
      std::memcpy(
          tmp_values,
          reinterpret_cast<const float*>(ptr),
          count * sizeof(float));
      return vld1q_f32_x2(reinterpret_cast<const float*>(tmp_values));
    }
  }
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      vst1q_f32_x2(reinterpret_cast<float*>(ptr), values);
    }
    else if (count == (size() >> 1)) {
      vst1q_f32(reinterpret_cast<float*>(ptr), values.val[0]);
    }
    else {
      float tmp_values[size()];
      vst1q_f32_x2(reinterpret_cast<float*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(float));
    }
  }
  inline const float32x4_t& get_low() const {
    return values.val[0];
  }
  inline float32x4_t& get_low() {
    return values.val[0];
  }
  inline const float32x4_t& get_high() const {
    return values.val[1];
  }
  inline float32x4_t& get_high() {
    return values.val[1];
  }
  // Very slow implementation of indexing.
  // Only required because vec256_qint refers to this.
  // Once that implementation is specialized for ARM
  // this should be removed. TODO (kimishpatel)
  float operator[](int idx) const {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  float operator[](int idx) {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  // For a boolean version, where we only want to know whether any lane is 1,
  // all lanes are zero, etc., this could be done faster in a different way.
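  // Example: Vectorized<float>(0.f, 1.f, 0.f, 2.f, 0.f, 3.f, 0.f, 4.f).zero_mask()
  // returns 0x55, i.e. bit i is set iff lane i compares equal to 0.f.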
  int zero_mask() const {
    __at_align__ float tmp[size()];
    store(tmp);
    int mask = 0;
    for (int i = 0; i < size(); ++i) {
      if (tmp[i] == 0.f) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
  Vectorized<float> isnan() const {
    __at_align__ float tmp[size()];
    __at_align__ float res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(float));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(float));
      }
    }
    return loadu(res);
  }
  bool has_inf_nan() const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }
  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> abs() const {
    return Vectorized<float>(vabsq_f32(values.val[0]), vabsq_f32(values.val[1]));
  }
  Vectorized<float> angle() const {
    auto zero = Vectorized<float>(0);
    auto pi = Vectorized<float>(c10::pi<float>);
    auto tmp = blendv(zero, pi, *this < zero);
    return blendv(tmp, *this, isnan());
  }
  Vectorized<float> real() const {
    return *this;
  }
  Vectorized<float> imag() const {
    return Vectorized<float>(0.f);
  }
  Vectorized<float> conj() const {
    return *this;
  }
  Vectorized<float> acos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_acosf4_u10(values.val[0]), Sleef_acosf4_u10(values.val[1])),
      map(std::acos)
    );
  }
  Vectorized<float> acosh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_acoshf4_u10(values.val[0]), Sleef_acoshf4_u10(values.val[1])),
      map(std::acosh)
    );
  }
  Vectorized<float> asin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_asinf4_u10(values.val[0]), Sleef_asinf4_u10(values.val[1])),
      map(std::asin)
    );
  }
  Vectorized<float> atan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_atanf4_u10(values.val[0]), Sleef_atanf4_u10(values.val[1])),
      map(std::atan)
    );
  }
  Vectorized<float> atanh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_atanhf4_u10(values.val[0]), Sleef_atanhf4_u10(values.val[1])),
      map(std::atanh)
    );
  }
  Vectorized<float> atan2(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_atan2f4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_atan2f4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::atan2(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> copysign(const Vectorized<float> &sign) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_copysignf4(values.val[0], sign.values.val[0]),
                                 Sleef_copysignf4(values.val[1], sign.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_sign[size()];
        store(tmp);
        sign.store(tmp_sign);
        for (size_type i = 0; i < size(); i++) {
          tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
        }
        return loadu(tmp);
      }
    )
  }
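  // erf() is defined out of line near the end of this file (after the fmadd
  // specialization); it uses a polynomial approximation rather than calling
  // std::erf element-wise.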
  Vectorized<float> erf() const;
  Vectorized<float> erfc() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_erfcf4_u15(values.val[0]), Sleef_erfcf4_u15(values.val[1])),
      map(std::erfc)
    );
  }
  Vectorized<float> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<float> exp() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expf4_u10(values.val[0]), Sleef_expf4_u10(values.val[1])),
      map(std::exp)
    );
  }
  Vectorized<float> exp2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_exp2f4_u10(values.val[0]), Sleef_exp2f4_u10(values.val[1])),
      map(std::exp2)
    );
  }
  Vectorized<float> expm1() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expm1f4_u10(values.val[0]), Sleef_expm1f4_u10(values.val[1])),
      map(std::expm1)
    );
  }
  Vectorized<float> exp_u20() const {
    return exp();
  }
  Vectorized<float> fmod(const Vectorized<float>& q) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_fmodf4(values.val[0], q.values.val[0]),
                                 Sleef_fmodf4(values.val[1], q.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_q[size()];
        store(tmp);
        q.store(tmp_q);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> hypot(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_hypotf4_u05(values.val[0], b.values.val[0]),
                                 Sleef_hypotf4_u05(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::hypot(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> i0() const {
    return map(calc_i0);
  }
  Vectorized<float> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<float> digamma() const {
    return map(calc_digamma);
  }
  Vectorized<float> igamma(const Vectorized<float> &x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> igammac(const Vectorized<float> &x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> log() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_logf4_u10(values.val[0]), Sleef_logf4_u10(values.val[1])),
      map(std::log)
    );
  }
  Vectorized<float> log10() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log10f4_u10(values.val[0]), Sleef_log10f4_u10(values.val[1])),
      map(std::log10)
    );
  }
  Vectorized<float> log1p() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log1pf4_u10(values.val[0]), Sleef_log1pf4_u10(values.val[1])),
      map(std::log1p)
    );
  }
  Vectorized<float> log2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log2f4_u10(values.val[0]), Sleef_log2f4_u10(values.val[1])),
      map(std::log2)
    );
  }
  Vectorized<float> nextafter(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_nextafterf4(values.val[0], b.values.val[0]),
                                 Sleef_nextafterf4(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> frac() const;
  Vectorized<float> sin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinf4_u10(values.val[0]), Sleef_sinf4_u10(values.val[1])),
      map(std::sin)
    );
  }
  Vectorized<float> sinh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinhf4_u10(values.val[0]), Sleef_sinhf4_u10(values.val[1])),
      map(std::sinh)
    );
  }
  Vectorized<float> cos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_cosf4_u10(values.val[0]), Sleef_cosf4_u10(values.val[1])),
      map(std::cos)
    );
  }
  Vectorized<float> cosh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_coshf4_u10(values.val[0]), Sleef_coshf4_u10(values.val[1])),
      map(std::cosh)
    );
  }
  Vectorized<float> ceil() const {
    return map(at::native::ceil_impl);
  }
  Vectorized<float> floor() const {
    return map(at::native::floor_impl);
  }
  Vectorized<float> neg() const {
    return Vectorized<float>(
        vnegq_f32(values.val[0]),
        vnegq_f32(values.val[1]));
  }
  Vectorized<float> round() const {
    // We do not use std::round because we would like to round midway numbers
    // to the nearest even integer.
    return map(at::native::round_impl);
  }
  Vectorized<float> tan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanf4_u10(values.val[0]), Sleef_tanf4_u10(values.val[1])),
      map(std::tan)
    );
  }
  Vectorized<float> tanh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanhf4_u10(values.val[0]), Sleef_tanhf4_u10(values.val[1])),
      map(std::tanh)
    );
  }
  Vectorized<float> trunc() const {
    float32x4_t r0 = vrndq_f32(values.val[0]);
    float32x4_t r1 = vrndq_f32(values.val[1]);
    return Vectorized<float>(r0, r1);
  }
  Vectorized<float> lgamma() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_lgammaf4_u10(values.val[0]), Sleef_lgammaf4_u10(values.val[1])),
      map(std::lgamma)
    );
  }
  Vectorized<float> sqrt() const {
    return Vectorized<float>(
        vsqrtq_f32(values.val[0]),
        vsqrtq_f32(values.val[1]));
  }
  Vectorized<float> reciprocal() const {
    auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]);
    auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]);
    return Vectorized<float>(r0, r1);
  }
  Vectorized<float> rsqrt() const {
    return this->sqrt().reciprocal();
  }
  Vectorized<float> pow(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_powf4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_powf4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::pow(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
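  // The comparison operators below return a bitwise mask per lane (all bits
  // set when the comparison holds, all bits clear otherwise), reinterpreted
  // as float. The eq/ne/gt/ge/lt/le member functions declared further down
  // return 1.0f / 0.0f per lane instead.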
  Vectorized<float> operator==(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator!=(const Vectorized<float>& other) const {
    float32x4_t r0 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[0], other.values.val[0])));
    float32x4_t r1 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[1], other.values.val[1])));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> eq(const Vectorized<float>& other) const;
  Vectorized<float> ne(const Vectorized<float>& other) const;
  Vectorized<float> gt(const Vectorized<float>& other) const;
  Vectorized<float> ge(const Vectorized<float>& other) const;
  Vectorized<float> lt(const Vectorized<float>& other) const;
  Vectorized<float> le(const Vectorized<float>& other) const;
};

template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vaddq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vaddq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vsubq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vsubq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vmulq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vmulq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vdivq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vdivq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}
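// Usage sketch (illustrative only; `src` and `dst` are hypothetical float
// pointers with at least Vectorized<float>::size() valid elements):
//   Vectorized<float> x = Vectorized<float>::loadu(src);
//   Vectorized<float> y = (x * x + Vectorized<float>(1.0f)) / x;
//   y.store(dst);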
// frac. Implement this here so we can use subtraction.
inline Vectorized<float> Vectorized<float>::frac() const {
  return *this - this->trunc();
}

// Added Sleef implementation for maximum.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
  if (!a.has_inf_nan() && !b.has_inf_nan()) {
    return USE_SLEEF(
      Vectorized<float>(Sleef_fmaxf4(a.get_low(), b.get_low()), Sleef_fmaxf4(a.get_high(), b.get_high())),
      Vectorized<float>(vmaxq_f32(a.get_low(), b.get_low()), vmaxq_f32(a.get_high(), b.get_high())));
  }
  else {
    return Vectorized<float>(vmaxq_f32(a.get_low(), b.get_low()), vmaxq_f32(a.get_high(), b.get_high()));
  }
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vminq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vminq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
  return minimum(max, a);
}

template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
  return maximum(min, a);
}

template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
  return (*this == other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
  return (*this != other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
  return (*this > other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
  return (*this >= other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
  return (*this < other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
  return (*this <= other) & Vectorized<float>(1.0f);
}

template <>
inline void convert(const float* src, int32_t* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
    vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4)));
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<int32_t>(src[i]);
  }
}

template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
    vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4)));
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}

template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  float32x4_t r0 = vfmaq_f32(c.get_low(), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(c.get_high(), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  float32x4_t r0 = vfmsq_f32(c.get_low(), a.get_low(), b.get_low());
  float32x4_t r1 = vfmsq_f32(c.get_high(), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

inline Vectorized<float> Vectorized<float>::erf() const {
  // Polynomial approximation constants (Abramowitz & Stegun formula 7.1.26).
  const Vectorized<float> neg_zero_vec(-0.f);
  const Vectorized<float> one_vec(1.0f);
  const Vectorized<float> p(0.3275911f);
  const Vectorized<float> p1(0.254829592f);
  const Vectorized<float> p2(-0.284496736f);
  const Vectorized<float> p3(1.421413741f);
  const Vectorized<float> p4(-1.453152027f);
  const Vectorized<float> p5(1.061405429f);
  // sign(x)
  auto sign_mask = neg_zero_vec & *this;
  auto abs_vec = this->abs();
  // t = 1 / (p * abs(x) + 1)
  auto tmp0 = fmadd(p, abs_vec, one_vec);
  auto t = one_vec / tmp0;
  // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
  auto tmp1 = fmadd(p5, t, p4);
  auto tmp2 = fmadd(tmp1, t, p3);
  auto tmp3 = fmadd(tmp2, t, p2);
  auto r = fmadd(tmp3, t, p1);
  // - exp(- x * x)
  auto pow_2 = (*this) * (*this);
  auto neg_pow_2 = pow_2 ^ neg_zero_vec;
  auto tmp4 = neg_pow_2.map(std::exp); // This can be swapped for a faster implementation of exp.
  auto tmp5 = tmp4 ^ neg_zero_vec;
  // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
  auto tmp6 = t * tmp5;
  auto tmp7 = fmadd(tmp6, r, one_vec);
  return tmp7 ^ sign_mask;
}
#endif /* defined(__aarch64__) */

} // namespace CPU_CAPABILITY
} // namespace at::vec