xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
17 
18 #include <stdint.h>
19 
20 #include <algorithm>
21 #include <cstdio>
22 #include <list>
23 #include <memory>
24 #include <utility>
25 
26 #include "llvm/ExecutionEngine/ExecutionEngine.h"
27 #include "llvm/ExecutionEngine/JITSymbol.h"
28 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
29 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
30 #include "llvm/IR/Mangler.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/Support/CodeGen.h"
33 #include "llvm/Support/Host.h"
34 #include "mlir/ExecutionEngine/CRunnerUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
36 #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
37 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
38 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_acl.h"
39 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
40 #include "tensorflow/compiler/xla/service/cpu/runtime_conv3d.h"
41 #include "tensorflow/compiler/xla/service/cpu/runtime_custom_call_status.h"
42 #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
43 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
44 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
45 #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
46 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
47 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_acl.h"
48 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
49 #include "tensorflow/compiler/xla/service/cpu/runtime_pow.h"
50 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
51 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv3d.h"
52 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
53 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
54 #include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
55 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
56 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
57 #include "tensorflow/compiler/xla/types.h"
58 #include "tensorflow/core/platform/logging.h"
59 
// Conversion helpers provided by compiler-rt and MLIR. Only their addresses
// are needed here (they are registered with the JIT's symbol table below).
extern "C" {
// Converts an F32 value to a BF16; the result is returned as a uint16_t.
uint16_t __truncsfbf2(float);
// Converts an F64 value to a BF16; the result is returned as a uint16_t.
uint16_t __truncdfbf2(double);
}  // extern "C"
65 
66 namespace xla {
67 namespace cpu {
68 namespace {
69 
DetectMachineAttributes()70 llvm::SmallVector<std::string, 0> DetectMachineAttributes() {
71   llvm::SmallVector<std::string, 0> result;
72   llvm::StringMap<bool> host_features;
73   if (llvm::sys::getHostCPUFeatures(host_features)) {
74     for (auto& feature : host_features) {
75       result.push_back((feature.second ? '+' : '-') +
76                        std::string(feature.first()));
77     }
78   }
79   return result;
80 }
81 
82 }  // namespace
83 
84 /*static*/ std::unique_ptr<llvm::TargetMachine>
InferTargetMachineForJIT(const llvm::TargetOptions & target_options,llvm::CodeGenOpt::Level opt_level)85 SimpleOrcJIT::InferTargetMachineForJIT(
86     const llvm::TargetOptions& target_options,
87     llvm::CodeGenOpt::Level opt_level) {
88   std::unique_ptr<llvm::TargetMachine> target_machine(
89       llvm::EngineBuilder()
90           .setTargetOptions(target_options)
91           .setOptLevel(opt_level)
92           .selectTarget(
93               /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
94               /*MCPU=*/llvm::sys::getHostCPUName(),
95               /*MAttrs=*/DetectMachineAttributes()));
96   CHECK(target_machine != nullptr);
97   return target_machine;
98 }
99 
SimpleOrcJIT(std::unique_ptr<llvm::orc::ExecutorProcessControl> target_process_control,std::unique_ptr<llvm::orc::ExecutionSession> execution_session,const llvm::TargetOptions & target_options,llvm::CodeGenOpt::Level opt_level,bool optimize_for_size,bool disable_expensive_passes,llvm::FastMathFlags fast_math_flags,LLVMCompiler::ModuleHook pre_optimization_hook,LLVMCompiler::ModuleHook post_optimization_hook,std::function<void (const llvm::object::ObjectFile &)> post_codegen_hook)100 SimpleOrcJIT::SimpleOrcJIT(
101     std::unique_ptr<llvm::orc::ExecutorProcessControl> target_process_control,
102     std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
103     const llvm::TargetOptions& target_options,
104     llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
105     bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
106     LLVMCompiler::ModuleHook pre_optimization_hook,
107     LLVMCompiler::ModuleHook post_optimization_hook,
108     std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
109     : target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
110       target_triple_(target_machine_->getTargetTriple()),
111       data_layout_(target_machine_->createDataLayout()),
112       target_process_control_(std::move(target_process_control)),
113       execution_session_(std::move(execution_session)),
114       object_layer_(*execution_session_,
115                     []() {
116                       return std::make_unique<llvm::SectionMemoryManager>(
117                           orc_jit_memory_mapper::GetInstance());
118                     }),
119       compile_layer_(
120           *execution_session_, object_layer_,
121           std::make_unique<CompilerFunctor>(
122               target_machine_.get(), opt_level, optimize_for_size,
123               disable_expensive_passes, fast_math_flags,
124               std::move(pre_optimization_hook),
125               std::move(post_optimization_hook), std::move(post_codegen_hook))),
126       main_jit_dylib_(&execution_session_->createBareJITDylib("<main>")),
127       gdb_jit_event_listener_(
128           llvm::JITEventListener::createGDBRegistrationListener()) {
129   VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
130           << " features: " << target_machine_->getTargetFeatureString().str();
131 
132   // Materialize unknown symbols from the runtime symbol table.
133   class RuntimeSymbolGenerator : public llvm::orc::DefinitionGenerator {
134     SimpleOrcJIT& jit_;
135 
136    public:
RuntimeSymbolGenerator(SimpleOrcJIT & jit)137     explicit RuntimeSymbolGenerator(SimpleOrcJIT& jit) : jit_(jit) {}
tryToGenerate(llvm::orc::LookupState &,llvm::orc::LookupKind,llvm::orc::JITDylib & jit_dylib,llvm::orc::JITDylibLookupFlags,const llvm::orc::SymbolLookupSet & names)138     llvm::Error tryToGenerate(
139         llvm::orc::LookupState&, llvm::orc::LookupKind,
140         llvm::orc::JITDylib& jit_dylib, llvm::orc::JITDylibLookupFlags,
141         const llvm::orc::SymbolLookupSet& names) override {
142       llvm::orc::SymbolMap new_defs;
143 
144       for (const auto& kv : names) {
145         const auto& name = kv.first;
146         if (llvm::JITEvaluatedSymbol symbol =
147                 jit_.ResolveRuntimeSymbol(*name)) {
148           new_defs[name] = symbol;
149         }
150       }
151 
152       cantFail(jit_dylib.define(absoluteSymbols(std::move(new_defs))));
153       return llvm::Error::success();
154     }
155   };
156   main_jit_dylib_->addGenerator(
157       std::make_unique<RuntimeSymbolGenerator>(*this));
158   object_layer_.registerJITEventListener(*this);
159 
160   // Copied from LLJIT, required to find symbols on Windows.
161   if (target_triple_.isOSBinFormatCOFF()) {
162     object_layer_.setOverrideObjectFlagsWithResponsibilityFlags(true);
163     object_layer_.setAutoClaimResponsibilityForObjectSymbols(true);
164   }
165 }
166 
~SimpleOrcJIT()167 SimpleOrcJIT::~SimpleOrcJIT() {
168   if (auto err = execution_session_->endSession()) {
169     execution_session_->reportError(std::move(err));
170   }
171 }
172 
// Factory for SimpleOrcJIT.
//
// Creates a SelfExecutorProcessControl; if that fails, its error is returned.
// Otherwise builds an ExecutionSession and forwards every remaining argument
// to the constructor.
llvm::Expected<std::unique_ptr<SimpleOrcJIT>> SimpleOrcJIT::Create(
    const llvm::TargetOptions& target_options,
    llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
    bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
    LLVMCompiler::ModuleHook pre_optimization_hook,
    LLVMCompiler::ModuleHook post_optimization_hook,
    std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook) {
  auto process_control = llvm::orc::SelfExecutorProcessControl::Create(
      std::make_shared<llvm::orc::SymbolStringPool>());
  if (!process_control) return process_control.takeError();

  // NOTE(review): the session is built around its own
  // UnsupportedExecutorProcessControl (and thus its own SymbolStringPool);
  // the self-process control above is only handed to the JIT to own.
  // Confirm this split is intentional.
  auto session = std::make_unique<llvm::orc::ExecutionSession>(
      std::make_unique<llvm::orc::UnsupportedExecutorProcessControl>());

  return std::make_unique<SimpleOrcJIT>(
      std::move(*process_control), std::move(session), target_options,
      opt_level, optimize_for_size, disable_expensive_passes, fast_math_flags,
      std::move(pre_optimization_hook), std::move(post_optimization_hook),
      std::move(post_codegen_hook));
}
195 
ResolveRuntimeSymbol(llvm::StringRef name)196 llvm::JITEvaluatedSymbol SimpleOrcJIT::ResolveRuntimeSymbol(
197     llvm::StringRef name) {
198   void* func_addr = nullptr;
199   if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
200     // On Mac OS X, 'name' may have a leading underscore prefix, even though the
201     // registered name may not.
202     std::string stripped_name(name.begin() + 1, name.end());
203     func_addr =
204         xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
205   } else {
206     func_addr =
207         xla::CustomCallTargetRegistry::Global()->Lookup(name.str(), "Host");
208   }
209 
210   if (func_addr == nullptr) {
211     LOG(ERROR)
212         << "Unable to resolve runtime symbol: `" << name.str()
213         << "'.  Hint: if the symbol a custom call target, make sure you've "
214            "registered it with the JIT using "
215            "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
216     return nullptr;
217   }
218   llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
219                                        llvm::JITSymbolFlags::None);
220   return symbol_info;
221 }
222 
notifyObjectLoaded(llvm::JITEventListener::ObjectKey key,const llvm::object::ObjectFile & object,const llvm::RuntimeDyld::LoadedObjectInfo & object_info)223 void SimpleOrcJIT::notifyObjectLoaded(
224     llvm::JITEventListener::ObjectKey key,
225     const llvm::object::ObjectFile& object,
226     const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
227   gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
228   size_of_generated_code_in_bytes_ += object.getData().size();
229 }
230 
notifyFreeingObject(llvm::JITEventListener::ObjectKey key)231 void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) {
232   gdb_jit_event_listener_->notifyFreeingObject(key);
233 }
234 
AddModule(llvm::orc::ThreadSafeModule module)235 llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module) {
236   return compile_layer_.add(*main_jit_dylib_, std::move(module));
237 }
238 
DoneCompiling()239 void SimpleOrcJIT::DoneCompiling() {
240   // The target machine takes a non-trivial amount of memory, so once we are
241   // done compiling throw it away.
242   target_machine_.reset();
243 }
244 
FindCompiledSymbol(const std::string & name)245 llvm::Expected<llvm::JITEvaluatedSymbol> SimpleOrcJIT::FindCompiledSymbol(
246     const std::string& name) {
247   return execution_session_->lookup({main_jit_dylib_}, name);
248 }
249 
#ifdef PLATFORM_WINDOWS
// Referenced by compiler-generated code on Windows but not declared anywhere
// we can include. Only its address is taken (for JIT symbol registration), so
// the declared signature does not matter.
extern "C" void __chkstk(size_t);
#endif  // PLATFORM_WINDOWS
255 
256 namespace {
257 // Register some known symbols with the CustomCallTargetRegistry.
RegisterKnownJITSymbols()258 bool RegisterKnownJITSymbols() {
259   xla::CustomCallTargetRegistry* registry =
260       xla::CustomCallTargetRegistry::Global();
261   registry->Register("printf", reinterpret_cast<void*>(&printf), "Host");
262   registry->Register("puts", reinterpret_cast<void*>(&puts), "Host");
263 
264 #define REGISTER_CPU_RUNTIME_SYMBOL(base_name)                               \
265   do {                                                                       \
266     auto* function_address =                                                 \
267         reinterpret_cast<void*>(__xla_cpu_runtime_##base_name);              \
268     registry->Register(xla::cpu::runtime::k##base_name##SymbolName,          \
269                        function_address, "Host");                            \
270     CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
271              "__xla_cpu_runtime_" #base_name);                               \
272   } while (false)
273 
274   REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
275   REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
276   REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
277   REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
278   REGISTER_CPU_RUNTIME_SYMBOL(AllToAll);
279   REGISTER_CPU_RUNTIME_SYMBOL(PartitionId);
280   REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
281   REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32);
282   REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF16);
283   REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF32);
284   REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF16);
285   REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF32);
286   REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
287   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
288   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
289   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
290   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC64);
291   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC128);
292   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32);
293   REGISTER_CPU_RUNTIME_SYMBOL(EigenBatchMatMulF32);
294   REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
295   REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
296   REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
297   REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
298   REGISTER_CPU_RUNTIME_SYMBOL(ACLMatMulF32);
299   REGISTER_CPU_RUNTIME_SYMBOL(ACLBatchMatMulF32);
300   REGISTER_CPU_RUNTIME_SYMBOL(ACLConv2DF32);
301   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF16);
302   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF32);
303   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF16);
304   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF32);
305   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft);
306   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
307   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
308   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
309   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC64);
310   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC128);
311   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32);
312   REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
313   REGISTER_CPU_RUNTIME_SYMBOL(PrintfToStderr);
314   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
315   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
316   REGISTER_CPU_RUNTIME_SYMBOL(StatusIsSuccess);
317   REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
318   REGISTER_CPU_RUNTIME_SYMBOL(TopKF32);
319   REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
320   REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
321 
322   registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
323                      "Host");
324   registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
325                      "Host");
326   registry->Register("__truncdfhf2", reinterpret_cast<void*>(__truncdfhf2),
327                      "Host");
328   registry->Register("__truncdfbf2", reinterpret_cast<void*>(__truncdfbf2),
329                      "Host");
330   registry->Register("__truncsfbf2", reinterpret_cast<void*>(__truncsfbf2),
331                      "Host");
332   registry->Register("__powisf2", reinterpret_cast<void*>(__powisf2), "Host");
333   registry->Register("__powidf2", reinterpret_cast<void*>(__powidf2), "Host");
334 
335 #undef REGISTER_CPU_RUNTIME_SYMBOL
336 
337 // Register both the f32 (float) and f64 (double) versions of a libm symbol.
338 // Unfortunately the double versions are overloaded on some systems, e.g.
339 // Mac so we need an explicit cast. This requires passing the function signature
340 // for that case.
341 #define REGISTER_LIBM_SYMBOL(name, double_sig)                                 \
342   do {                                                                         \
343     registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host");   \
344     registry->Register(#name,                                                  \
345                        reinterpret_cast<void*>(static_cast<double_sig>(name)), \
346                        "Host");                                                \
347   } while (false)
348 
349   REGISTER_LIBM_SYMBOL(acos, double (*)(double));
350   REGISTER_LIBM_SYMBOL(acosh, double (*)(double));
351   REGISTER_LIBM_SYMBOL(asin, double (*)(double));
352   REGISTER_LIBM_SYMBOL(asinh, double (*)(double));
353   REGISTER_LIBM_SYMBOL(atan, double (*)(double));
354   REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double));
355   REGISTER_LIBM_SYMBOL(atanh, double (*)(double));
356   REGISTER_LIBM_SYMBOL(cbrt, double (*)(double));
357   REGISTER_LIBM_SYMBOL(ceil, double (*)(double));
358   REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double));
359   REGISTER_LIBM_SYMBOL(cos, double (*)(double));
360   REGISTER_LIBM_SYMBOL(cosh, double (*)(double));
361   REGISTER_LIBM_SYMBOL(erf, double (*)(double));
362   REGISTER_LIBM_SYMBOL(erfc, double (*)(double));
363   REGISTER_LIBM_SYMBOL(exp, double (*)(double));
364   REGISTER_LIBM_SYMBOL(exp2, double (*)(double));
365   REGISTER_LIBM_SYMBOL(expm1, double (*)(double));
366   REGISTER_LIBM_SYMBOL(fabs, double (*)(double));
367   REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double));
368   REGISTER_LIBM_SYMBOL(floor, double (*)(double));
369   REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double));
370   REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double));
371   REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double));
372   REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double));
373   REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*));
374   REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double));
375   REGISTER_LIBM_SYMBOL(ilogb, int (*)(double));
376   REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int));
377   REGISTER_LIBM_SYMBOL(lgamma, double (*)(double));
378   REGISTER_LIBM_SYMBOL(llrint, long long (*)(double));   // NOLINT(runtime/int)
379   REGISTER_LIBM_SYMBOL(llround, long long (*)(double));  // NOLINT(runtime/int)
380   REGISTER_LIBM_SYMBOL(log, double (*)(double));
381   REGISTER_LIBM_SYMBOL(log10, double (*)(double));
382   REGISTER_LIBM_SYMBOL(log1p, double (*)(double));
383   REGISTER_LIBM_SYMBOL(log2, double (*)(double));
384   REGISTER_LIBM_SYMBOL(logb, double (*)(double));
385   REGISTER_LIBM_SYMBOL(lrint, long (*)(double));   // NOLINT(runtime/int)
386   REGISTER_LIBM_SYMBOL(lround, long (*)(double));  // NOLINT(runtime/int)
387   REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*));
388   REGISTER_LIBM_SYMBOL(nan, double (*)(const char*));
389   REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double));
390   REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double));
391   REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double));
392   REGISTER_LIBM_SYMBOL(pow, double (*)(double, double));
393   REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double));
394   REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*));
395   REGISTER_LIBM_SYMBOL(rint, double (*)(double));
396   REGISTER_LIBM_SYMBOL(round, double (*)(double));
397   REGISTER_LIBM_SYMBOL(scalbln,
398                        double (*)(double, long));  // NOLINT(runtime/int)
399   REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int));
400   REGISTER_LIBM_SYMBOL(sin, double (*)(double));
401 #ifdef __APPLE__
402   REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
403   registry->Register("__sincosf_stret",
404                      reinterpret_cast<void*>(__sincosf_stret), "Host");
405   registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
406                      "Host");
407 #else
408   REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
409 #endif
410   REGISTER_LIBM_SYMBOL(sinh, double (*)(double));
411   REGISTER_LIBM_SYMBOL(sqrt, double (*)(double));
412   REGISTER_LIBM_SYMBOL(tan, double (*)(double));
413   REGISTER_LIBM_SYMBOL(tanh, double (*)(double));
414   REGISTER_LIBM_SYMBOL(tgamma, double (*)(double));
415   REGISTER_LIBM_SYMBOL(trunc, double (*)(double));
416 
417 #undef REGISTER_LIBM_SYMBOL
418 
419   registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
420   registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
421   registry->Register("memset", reinterpret_cast<void*>(memset), "Host");
422 
423   // Used by MLIR lowering.
424   registry->Register("malloc", reinterpret_cast<void*>(malloc), "Host");
425   registry->Register("calloc", reinterpret_cast<void*>(calloc), "Host");
426   registry->Register("free", reinterpret_cast<void*>(free), "Host");
427 #ifndef _WIN32
428   // TODO(kramerb): This fails to link on windows because it's marked dllimport.
429   registry->Register("memrefCopy", reinterpret_cast<void*>(memrefCopy), "Host");
430 #endif
431 
432 #ifdef __APPLE__
433   registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
434   registry->Register("bzero", reinterpret_cast<void*>(bzero), "Host");
435   registry->Register("memset_pattern16",
436                      reinterpret_cast<void*>(memset_pattern16), "Host");
437 #endif
438 
439 #ifdef MEMORY_SANITIZER
440   registry->Register("__msan_unpoison",
441                      reinterpret_cast<void*>(__msan_unpoison), "Host");
442 #endif
443 
444 #if defined(PLATFORM_WINDOWS)
445   registry->Register("__chkstk", reinterpret_cast<void*>(__chkstk), "Host");
446 #endif
447 
448   return true;
449 }
450 
451 bool unused = RegisterKnownJITSymbols();
452 }  // namespace
453 
454 }  // namespace cpu
455 }  // namespace xla
456