xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/FlushDenormal.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cpu/FlushDenormal.h>
2 #include <ATen/cpu/vec/intrinsics.h>
3 #if !defined(__s390x__) && !defined(__powerpc__)
4 #include <cpuinfo.h>
5 #endif
6 
7 namespace at::cpu {
8 
9 #if defined(__SSE__) || defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
// Bit masks for the control/status word that set_flush_denormal() reads with
// _mm_getcsr() and writes back with _mm_setcsr().
static constexpr unsigned int DENORMALS_ZERO = 1u << 6;   // 0x0040
static constexpr unsigned int FLUSH_ZERO = 1u << 15;      // 0x8000
12 
set_flush_denormal(bool on)13 bool set_flush_denormal(bool on) {
14   // Compile if we have SSE support (GCC), x86-64 (MSVC), or x86 with SSE (MSVC)
15   // Denormals-Are-Zero is supported by most SSE2 processors, with the exception
16   // of some early Pentium 4 processors. We guard it with a runtime check.
17   // Flush-To-Zero (FTZ) only requires SSE.
18   if (cpuinfo_has_x86_daz()) {
19     unsigned int csr = _mm_getcsr();
20     csr &= ~DENORMALS_ZERO;
21     csr &= ~FLUSH_ZERO;
22     if (on) {
23       csr |= DENORMALS_ZERO;
24       csr |= FLUSH_ZERO;
25     }
26     _mm_setcsr(csr);
27     return true;
28   }
29   return false;
30 }
31 #elif defined(__ARM_FP) && (__ARM_FP > 0)
32 // Imported from TensorFlow, tensorflow/third_party/xla/third_party/tsl/tsl/platform/denormal.cc
33 // Copyright 2015 The TensorFlow Authors. All Rights Reserved.
34 
// Flush-to-zero bit on the ARM floating-point control register.
// Declared as a typed constant rather than a macro so the mask has the same
// unsigned 32-bit type as the register image it is combined with.
static constexpr uint32_t ARM_FPCR_FZ = uint32_t{1} << 24;
37 
// Writes `fpcr` to the floating-point control register: FPCR (via `msr`) when
// compiled for AArch64, FPSCR (via `vmsr`) otherwise.
static inline void ArmSetFloatingPointControlRegister(uint32_t fpcr) {
#if !defined(__aarch64__)
  __asm__ __volatile__("vmsr fpscr, %[value]" : : [value] "r"(fpcr));
#else
  // The value is handed to the instruction as a 64-bit operand.
  const uint64_t widened = static_cast<uint64_t>(fpcr);
  __asm__ __volatile__("msr fpcr, %[value]" : : [value] "r"(widened));
#endif
}
47 
// Reads the floating-point control register and returns it as 32 bits:
// FPCR (via `mrs`) when compiled for AArch64, FPSCR (via `vmrs`) otherwise.
static inline uint32_t ArmGetFloatingPointControlRegister() {
#if defined(__aarch64__)
  // The instruction yields a 64-bit operand; only the low 32 bits are kept.
  uint64_t raw;
  __asm__ __volatile__("mrs %[out], fpcr" : [out] "=r"(raw));
  return static_cast<uint32_t>(raw);
#else
  uint32_t raw;
  __asm__ __volatile__("vmrs %[out], fpscr" : [out] "=r"(raw));
  return raw;
#endif
}
59 
60 bool set_flush_denormal(bool on) {
61     uint32_t fpcr = ArmGetFloatingPointControlRegister();
62     if (on) {
63       fpcr |= ARM_FPCR_FZ;
64     } else {
65       fpcr &= ~ ARM_FPCR_FZ;
66     }
67     ArmSetFloatingPointControlRegister(fpcr);
68     return true;
69 }
70 #else
// Fallback used when neither the SSE nor the ARM floating-point branch above
// is compiled in: nothing is changed and false is returned for any `on`.
bool set_flush_denormal(bool on) {
  static_cast<void>(on);  // intentionally unused in this configuration
  const bool applied = false;
  return applied;
}
74 #endif
75 
76 }  // namespace at::cpu
77