1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/lite/kernels/cpu_backend_context.h" 17 18 #include <memory> 19 20 #ifdef TFLITE_HAVE_CPUINFO 21 #include "include/cpuinfo.h" 22 #endif 23 24 #include "public/gemmlowp.h" 25 #include "ruy/context.h" // from @ruy 26 #include "ruy/path.h" // from @ruy 27 #include "tensorflow/lite/c/common.h" 28 #include "tensorflow/lite/core/macros.h" 29 #include "tensorflow/lite/external_cpu_backend_context.h" 30 #include "tensorflow/lite/kernels/internal/compatibility.h" 31 #include "tensorflow/lite/kernels/op_macros.h" 32 33 namespace { 34 const int kDefaultNumThreadpoolThreads = 1; 35 36 } // namespace 37 38 namespace tflite { 39 40 // Use weak symbols if possible to dispatch to deprecated paths. 41 #if TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__) 42 extern TFLITE_ATTRIBUTE_WEAK bool UseGemmlowpOnX86(); 43 #endif // defined(TFLITE_HAS_ATTRIBUTE_WEAK) && !(__APPLE__) 44 45 // TODO(b/138922878) Enable when Ruy builds on Apple. 46 #if defined(TFLITE_HAVE_CPUINFO) && !defined(__APPLE__) ~CpuInfo()47CpuBackendContext::CpuInfo::~CpuInfo() { 48 if (init_status_ == InitStatus::kInitialized) { 49 cpuinfo_deinitialize(); 50 } 51 } 52 EnsureInitialized()53bool CpuBackendContext::CpuInfo::EnsureInitialized() { 54 if (init_status_ == InitStatus::kNotYetAttempted) { 55 init_status_ = Initialize(); 56 } 57 return init_status_ == InitStatus::kInitialized; 58 } 59 60 CpuBackendContext::CpuInfo::InitStatus Initialize()61CpuBackendContext::CpuInfo::Initialize() { 62 TFLITE_DCHECK_EQ(init_status_, InitStatus::kNotYetAttempted); 63 if (!cpuinfo_initialize()) { 64 return InitStatus::kFailed; 65 } 66 return InitStatus::kInitialized; 67 } 68 Avx2Fma()69bool CpuBackendContext::CpuInfo::Avx2Fma() { 70 return EnsureInitialized() && cpuinfo_has_x86_avx2() && 71 cpuinfo_has_x86_fma3(); 72 } 73 Avx()74bool CpuBackendContext::CpuInfo::Avx() { 75 return EnsureInitialized() && cpuinfo_has_x86_avx(); 76 } 77 Avx512()78bool CpuBackendContext::CpuInfo::Avx512() { 79 return EnsureInitialized() && cpuinfo_has_x86_avx512f() && 80 cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512cd() && 81 cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512vl(); 82 } 83 #else 84 ~CpuInfo()85CpuBackendContext::CpuInfo::~CpuInfo() {} 86 EnsureInitialized()87bool CpuBackendContext::CpuInfo::EnsureInitialized() { 88 if (init_status_ == InitStatus::kNotYetAttempted) { 89 init_status_ = InitStatus::kInitialized; 90 } 91 TFLITE_DCHECK_EQ(init_status_, InitStatus::kInitialized); 92 return true; 93 } 94 Avx2Fma()95bool CpuBackendContext::CpuInfo::Avx2Fma() { return false; } 96 Avx()97bool CpuBackendContext::CpuInfo::Avx() { return false; } 98 Avx512()99bool CpuBackendContext::CpuInfo::Avx512() { return false; } 100 #endif // TFLITE_HAVE_CPUINFO 101 GetFromContext(TfLiteContext * context)102CpuBackendContext* CpuBackendContext::GetFromContext(TfLiteContext* context) { 103 auto* external_context = static_cast<ExternalCpuBackendContext*>( 104 context->GetExternalContext(context, kTfLiteCpuBackendContext)); 105 106 if (external_context == nullptr) { 107 TF_LITE_FATAL( 108 "ExternalCpuBackendContext isn't properly initialized during TFLite " 109 "interpreter initialization."); 110 } 111 112 auto* cpu_backend_context = static_cast<CpuBackendContext*>( 113 external_context->internal_backend_context()); 114 if (cpu_backend_context == nullptr) { 115 // We do the lazy initialization here for the TfLiteInternalBackendContext 116 // that's wrapped inside ExternalCpuBackendContext. 117 cpu_backend_context = new CpuBackendContext(); 118 cpu_backend_context->SetMaxNumThreads(context->recommended_num_threads); 119 external_context->set_internal_backend_context( 120 std::unique_ptr<TfLiteInternalBackendContext>(cpu_backend_context)); 121 } 122 123 return cpu_backend_context; 124 } 125 CpuBackendContext()126CpuBackendContext::CpuBackendContext() 127 : TfLiteInternalBackendContext(), 128 ruy_context_(new ruy::Context), 129 gemmlowp_context_(new gemmlowp::GemmContext) { 130 SetMaxNumThreads(kDefaultNumThreadpoolThreads); 131 // TODO(b/148289189) Remove when clients have transitioned to runtime flag. 132 #ifdef TFLITE_WITH_RUY_GEMV 133 SetUseCaching(true); 134 #else 135 SetUseCaching(false); 136 #endif 137 } 138 ~CpuBackendContext()139CpuBackendContext::~CpuBackendContext() {} 140 SetMaxNumThreads(int max_num_threads)141void CpuBackendContext::SetMaxNumThreads(int max_num_threads) { 142 const int target_num_threads = 143 max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads; 144 max_num_threads_ = target_num_threads; 145 ruy_context_->set_max_num_threads(target_num_threads); 146 gemmlowp_context_->set_max_num_threads(target_num_threads); 147 } 148 SetUseCaching(bool flag)149void CpuBackendContext::SetUseCaching(bool flag) { use_caching_ = flag; } 150 PreferGemmlowpOnX86()151bool CpuBackendContext::PreferGemmlowpOnX86() { 152 bool use_gemmlowp_on_x86 = false; 153 #if defined(TFLITE_X86_PLATFORM) && TFLITE_HAS_ATTRIBUTE_WEAK && \ 154 !defined(__APPLE__) 155 if (::tflite::UseGemmlowpOnX86 != nullptr) { 156 use_gemmlowp_on_x86 = ::tflite::UseGemmlowpOnX86(); 157 } 158 #endif // TFLITE_X86_PLATFORM && TFLITE_HAS_ATTRIBUTE_WEAK && !(__APPLE__) 159 return use_gemmlowp_on_x86 || !RuyHasAvxOrAbove(); 160 } 161 RuyHasAvxOrAbove()162bool CpuBackendContext::RuyHasAvxOrAbove() { 163 // TODO(b/183178387): Use a proper query to detect AVX/optimized paths. 164 #if RUY_PLATFORM_X86_ENHANCEMENTS 165 return cpuinfo_.Avx() || cpuinfo_.Avx2Fma() || cpuinfo_.Avx512(); 166 #else 167 return false; 168 #endif 169 } 170 171 } // namespace tflite 172