1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
17
18 #include <stdint.h>
19
20 #include <algorithm>
21 #include <cstdio>
22 #include <list>
23 #include <memory>
24 #include <utility>
25
26 #include "llvm/ExecutionEngine/ExecutionEngine.h"
27 #include "llvm/ExecutionEngine/JITSymbol.h"
28 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
29 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
30 #include "llvm/IR/Mangler.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/Support/CodeGen.h"
33 #include "llvm/Support/Host.h"
34 #include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project
35 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
36 #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
37 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
38 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_acl.h"
39 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
40 #include "tensorflow/compiler/xla/service/cpu/runtime_conv3d.h"
41 #include "tensorflow/compiler/xla/service/cpu/runtime_custom_call_status.h"
42 #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
43 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
44 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
45 #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
46 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
47 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_acl.h"
48 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
49 #include "tensorflow/compiler/xla/service/cpu/runtime_pow.h"
50 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
51 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv3d.h"
52 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
53 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
54 #include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
55 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
56 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
57 #include "tensorflow/compiler/xla/types.h"
58 #include "tensorflow/core/platform/logging.h"
59
// Provided by compiler-rt and MLIR. They are forward-declared here so that
// RegisterKnownJITSymbols() below can take their addresses and register them
// under the same names.
// Converts an F32 value to a BF16.
extern "C" uint16_t __truncsfbf2(float);
// Converts an F64 value to a BF16.
extern "C" uint16_t __truncdfbf2(double);
65
66 namespace xla {
67 namespace cpu {
68 namespace {
69
DetectMachineAttributes()70 llvm::SmallVector<std::string, 0> DetectMachineAttributes() {
71 llvm::SmallVector<std::string, 0> result;
72 llvm::StringMap<bool> host_features;
73 if (llvm::sys::getHostCPUFeatures(host_features)) {
74 for (auto& feature : host_features) {
75 result.push_back((feature.second ? '+' : '-') +
76 std::string(feature.first()));
77 }
78 }
79 return result;
80 }
81
82 } // namespace
83
84 /*static*/ std::unique_ptr<llvm::TargetMachine>
InferTargetMachineForJIT(const llvm::TargetOptions & target_options,llvm::CodeGenOpt::Level opt_level)85 SimpleOrcJIT::InferTargetMachineForJIT(
86 const llvm::TargetOptions& target_options,
87 llvm::CodeGenOpt::Level opt_level) {
88 std::unique_ptr<llvm::TargetMachine> target_machine(
89 llvm::EngineBuilder()
90 .setTargetOptions(target_options)
91 .setOptLevel(opt_level)
92 .selectTarget(
93 /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
94 /*MCPU=*/llvm::sys::getHostCPUName(),
95 /*MAttrs=*/DetectMachineAttributes()));
96 CHECK(target_machine != nullptr);
97 return target_machine;
98 }
99
SimpleOrcJIT(std::unique_ptr<llvm::orc::ExecutorProcessControl> target_process_control,std::unique_ptr<llvm::orc::ExecutionSession> execution_session,const llvm::TargetOptions & target_options,llvm::CodeGenOpt::Level opt_level,bool optimize_for_size,bool disable_expensive_passes,llvm::FastMathFlags fast_math_flags,LLVMCompiler::ModuleHook pre_optimization_hook,LLVMCompiler::ModuleHook post_optimization_hook,std::function<void (const llvm::object::ObjectFile &)> post_codegen_hook)100 SimpleOrcJIT::SimpleOrcJIT(
101 std::unique_ptr<llvm::orc::ExecutorProcessControl> target_process_control,
102 std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
103 const llvm::TargetOptions& target_options,
104 llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
105 bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
106 LLVMCompiler::ModuleHook pre_optimization_hook,
107 LLVMCompiler::ModuleHook post_optimization_hook,
108 std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
109 : target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
110 target_triple_(target_machine_->getTargetTriple()),
111 data_layout_(target_machine_->createDataLayout()),
112 target_process_control_(std::move(target_process_control)),
113 execution_session_(std::move(execution_session)),
114 object_layer_(*execution_session_,
115 []() {
116 return std::make_unique<llvm::SectionMemoryManager>(
117 orc_jit_memory_mapper::GetInstance());
118 }),
119 compile_layer_(
120 *execution_session_, object_layer_,
121 std::make_unique<CompilerFunctor>(
122 target_machine_.get(), opt_level, optimize_for_size,
123 disable_expensive_passes, fast_math_flags,
124 std::move(pre_optimization_hook),
125 std::move(post_optimization_hook), std::move(post_codegen_hook))),
126 main_jit_dylib_(&execution_session_->createBareJITDylib("<main>")),
127 gdb_jit_event_listener_(
128 llvm::JITEventListener::createGDBRegistrationListener()) {
129 VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
130 << " features: " << target_machine_->getTargetFeatureString().str();
131
132 // Materialize unknown symbols from the runtime symbol table.
133 class RuntimeSymbolGenerator : public llvm::orc::DefinitionGenerator {
134 SimpleOrcJIT& jit_;
135
136 public:
RuntimeSymbolGenerator(SimpleOrcJIT & jit)137 explicit RuntimeSymbolGenerator(SimpleOrcJIT& jit) : jit_(jit) {}
tryToGenerate(llvm::orc::LookupState &,llvm::orc::LookupKind,llvm::orc::JITDylib & jit_dylib,llvm::orc::JITDylibLookupFlags,const llvm::orc::SymbolLookupSet & names)138 llvm::Error tryToGenerate(
139 llvm::orc::LookupState&, llvm::orc::LookupKind,
140 llvm::orc::JITDylib& jit_dylib, llvm::orc::JITDylibLookupFlags,
141 const llvm::orc::SymbolLookupSet& names) override {
142 llvm::orc::SymbolMap new_defs;
143
144 for (const auto& kv : names) {
145 const auto& name = kv.first;
146 if (llvm::JITEvaluatedSymbol symbol =
147 jit_.ResolveRuntimeSymbol(*name)) {
148 new_defs[name] = symbol;
149 }
150 }
151
152 cantFail(jit_dylib.define(absoluteSymbols(std::move(new_defs))));
153 return llvm::Error::success();
154 }
155 };
156 main_jit_dylib_->addGenerator(
157 std::make_unique<RuntimeSymbolGenerator>(*this));
158 object_layer_.registerJITEventListener(*this);
159
160 // Copied from LLJIT, required to find symbols on Windows.
161 if (target_triple_.isOSBinFormatCOFF()) {
162 object_layer_.setOverrideObjectFlagsWithResponsibilityFlags(true);
163 object_layer_.setAutoClaimResponsibilityForObjectSymbols(true);
164 }
165 }
166
~SimpleOrcJIT()167 SimpleOrcJIT::~SimpleOrcJIT() {
168 if (auto err = execution_session_->endSession()) {
169 execution_session_->reportError(std::move(err));
170 }
171 }
172
// Factory for SimpleOrcJIT. Creates an in-process executor control with a
// fresh symbol string pool and a new execution session, then forwards them
// together with the caller's options and hooks to the constructor. Returns
// the error from SelfExecutorProcessControl::Create if that step fails.
llvm::Expected<std::unique_ptr<SimpleOrcJIT>> SimpleOrcJIT::Create(
    const llvm::TargetOptions& target_options,
    llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
    bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags,
    LLVMCompiler::ModuleHook pre_optimization_hook,
    LLVMCompiler::ModuleHook post_optimization_hook,
    std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook) {
  auto process_control = llvm::orc::SelfExecutorProcessControl::Create(
      std::make_shared<llvm::orc::SymbolStringPool>());
  if (!process_control) {
    return process_control.takeError();
  }

  auto unsupported_control =
      std::make_unique<llvm::orc::UnsupportedExecutorProcessControl>();
  auto session = std::make_unique<llvm::orc::ExecutionSession>(
      std::move(unsupported_control));

  return std::make_unique<SimpleOrcJIT>(
      std::move(*process_control), std::move(session), target_options,
      opt_level, optimize_for_size, disable_expensive_passes, fast_math_flags,
      std::move(pre_optimization_hook), std::move(post_optimization_hook),
      std::move(post_codegen_hook));
}
195
ResolveRuntimeSymbol(llvm::StringRef name)196 llvm::JITEvaluatedSymbol SimpleOrcJIT::ResolveRuntimeSymbol(
197 llvm::StringRef name) {
198 void* func_addr = nullptr;
199 if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
200 // On Mac OS X, 'name' may have a leading underscore prefix, even though the
201 // registered name may not.
202 std::string stripped_name(name.begin() + 1, name.end());
203 func_addr =
204 xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
205 } else {
206 func_addr =
207 xla::CustomCallTargetRegistry::Global()->Lookup(name.str(), "Host");
208 }
209
210 if (func_addr == nullptr) {
211 LOG(ERROR)
212 << "Unable to resolve runtime symbol: `" << name.str()
213 << "'. Hint: if the symbol a custom call target, make sure you've "
214 "registered it with the JIT using "
215 "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
216 return nullptr;
217 }
218 llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
219 llvm::JITSymbolFlags::None);
220 return symbol_info;
221 }
222
notifyObjectLoaded(llvm::JITEventListener::ObjectKey key,const llvm::object::ObjectFile & object,const llvm::RuntimeDyld::LoadedObjectInfo & object_info)223 void SimpleOrcJIT::notifyObjectLoaded(
224 llvm::JITEventListener::ObjectKey key,
225 const llvm::object::ObjectFile& object,
226 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
227 gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
228 size_of_generated_code_in_bytes_ += object.getData().size();
229 }
230
notifyFreeingObject(llvm::JITEventListener::ObjectKey key)231 void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) {
232 gdb_jit_event_listener_->notifyFreeingObject(key);
233 }
234
AddModule(llvm::orc::ThreadSafeModule module)235 llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module) {
236 return compile_layer_.add(*main_jit_dylib_, std::move(module));
237 }
238
DoneCompiling()239 void SimpleOrcJIT::DoneCompiling() {
240 // The target machine takes a non-trivial amount of memory, so once we are
241 // done compiling throw it away.
242 target_machine_.reset();
243 }
244
FindCompiledSymbol(const std::string & name)245 llvm::Expected<llvm::JITEvaluatedSymbol> SimpleOrcJIT::FindCompiledSymbol(
246 const std::string& name) {
247 return execution_session_->lookup({main_jit_dylib_}, name);
248 }
249
#if defined(PLATFORM_WINDOWS)
// This function is used by compiler-generated code on Windows, but it's not
// declared anywhere. The signature does not matter, we just need the address,
// which RegisterKnownJITSymbols() registers under the name "__chkstk".
extern "C" void __chkstk(size_t);
#endif
255
256 namespace {
// Registers a fixed set of host functions with the global
// CustomCallTargetRegistry, all under the "Host" platform. This is the same
// registry and platform that SimpleOrcJIT::ResolveRuntimeSymbol() queries.
//
// The set covers, in order: printf/puts, the XLA CPU runtime entry points,
// half/bfloat16 conversion and powi helpers, the float and double variants of
// the libm functions, memory routines, and a few platform- or
// sanitizer-specific symbols guarded by preprocessor checks.
//
// Always returns true; the return value exists so the function can be used as
// a variable initializer.
bool RegisterKnownJITSymbols() {
  xla::CustomCallTargetRegistry* registry =
      xla::CustomCallTargetRegistry::Global();
  registry->Register("printf", reinterpret_cast<void*>(&printf), "Host");
  registry->Register("puts", reinterpret_cast<void*>(&puts), "Host");

// Registers the function `__xla_cpu_runtime_<base_name>` under the name held
// in the constant `xla::cpu::runtime::k<base_name>SymbolName`, and CHECKs that
// the constant's value is exactly "__xla_cpu_runtime_<base_name>" so the
// constant and the function name cannot drift apart.
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name)                               \
  do {                                                                       \
    auto* function_address =                                                 \
        reinterpret_cast<void*>(__xla_cpu_runtime_##base_name);              \
    registry->Register(xla::cpu::runtime::k##base_name##SymbolName,          \
                       function_address, "Host");                            \
    CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
             "__xla_cpu_runtime_" #base_name);                               \
  } while (false)

  REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
  REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
  REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
  REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
  REGISTER_CPU_RUNTIME_SYMBOL(AllToAll);
  REGISTER_CPU_RUNTIME_SYMBOL(PartitionId);
  REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC64);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC128);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenBatchMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(ACLMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(ACLBatchMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(ACLConv2DF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv2DF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC64);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC128);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32);
  REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
  REGISTER_CPU_RUNTIME_SYMBOL(PrintfToStderr);
  REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
  REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
  REGISTER_CPU_RUNTIME_SYMBOL(StatusIsSuccess);
  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
  REGISTER_CPU_RUNTIME_SYMBOL(TopKF32);
  REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
  REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);

  // Half/bfloat16 conversion and integer-power helpers, registered under
  // their own names.
  registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
                     "Host");
  registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
                     "Host");
  registry->Register("__truncdfhf2", reinterpret_cast<void*>(__truncdfhf2),
                     "Host");
  registry->Register("__truncdfbf2", reinterpret_cast<void*>(__truncdfbf2),
                     "Host");
  registry->Register("__truncsfbf2", reinterpret_cast<void*>(__truncsfbf2),
                     "Host");
  registry->Register("__powisf2", reinterpret_cast<void*>(__powisf2), "Host");
  registry->Register("__powidf2", reinterpret_cast<void*>(__powidf2), "Host");

#undef REGISTER_CPU_RUNTIME_SYMBOL

// Register both the f32 (float) and f64 (double) versions of a libm symbol.
// Unfortunately the double versions are overloaded on some systems, e.g.
// Mac so we need an explicit cast. This requires passing the function signature
// for that case.
//
// `name` is the double-precision function; the float version is assumed to be
// the same identifier with an "f" suffix. `double_sig` is the function-pointer
// type used to select the double overload.
#define REGISTER_LIBM_SYMBOL(name, double_sig)                                 \
  do {                                                                         \
    registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host");   \
    registry->Register(#name,                                                  \
                       reinterpret_cast<void*>(static_cast<double_sig>(name)), \
                       "Host");                                                \
  } while (false)

  REGISTER_LIBM_SYMBOL(acos, double (*)(double));
  REGISTER_LIBM_SYMBOL(acosh, double (*)(double));
  REGISTER_LIBM_SYMBOL(asin, double (*)(double));
  REGISTER_LIBM_SYMBOL(asinh, double (*)(double));
  REGISTER_LIBM_SYMBOL(atan, double (*)(double));
  REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(atanh, double (*)(double));
  REGISTER_LIBM_SYMBOL(cbrt, double (*)(double));
  REGISTER_LIBM_SYMBOL(ceil, double (*)(double));
  REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(cos, double (*)(double));
  REGISTER_LIBM_SYMBOL(cosh, double (*)(double));
  REGISTER_LIBM_SYMBOL(erf, double (*)(double));
  REGISTER_LIBM_SYMBOL(erfc, double (*)(double));
  REGISTER_LIBM_SYMBOL(exp, double (*)(double));
  REGISTER_LIBM_SYMBOL(exp2, double (*)(double));
  REGISTER_LIBM_SYMBOL(expm1, double (*)(double));
  REGISTER_LIBM_SYMBOL(fabs, double (*)(double));
  REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(floor, double (*)(double));
  REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double));
  REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*));
  REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(ilogb, int (*)(double));
  REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int));
  REGISTER_LIBM_SYMBOL(lgamma, double (*)(double));
  REGISTER_LIBM_SYMBOL(llrint, long long (*)(double));   // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(llround, long long (*)(double));  // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(log, double (*)(double));
  REGISTER_LIBM_SYMBOL(log10, double (*)(double));
  REGISTER_LIBM_SYMBOL(log1p, double (*)(double));
  REGISTER_LIBM_SYMBOL(log2, double (*)(double));
  REGISTER_LIBM_SYMBOL(logb, double (*)(double));
  REGISTER_LIBM_SYMBOL(lrint, long (*)(double));   // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(lround, long (*)(double));  // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*));
  REGISTER_LIBM_SYMBOL(nan, double (*)(const char*));
  REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double));
  REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double));
  REGISTER_LIBM_SYMBOL(pow, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*));
  REGISTER_LIBM_SYMBOL(rint, double (*)(double));
  REGISTER_LIBM_SYMBOL(round, double (*)(double));
  REGISTER_LIBM_SYMBOL(scalbln,
                       double (*)(double, long));  // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int));
  REGISTER_LIBM_SYMBOL(sin, double (*)(double));
  // sincos is registered under different names depending on the platform:
  // the __sincos/__sincosf and *_stret names on Apple, plain sincos/sincosf
  // elsewhere.
#ifdef __APPLE__
  REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
  registry->Register("__sincosf_stret",
                     reinterpret_cast<void*>(__sincosf_stret), "Host");
  registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
                     "Host");
#else
  REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
#endif
  REGISTER_LIBM_SYMBOL(sinh, double (*)(double));
  REGISTER_LIBM_SYMBOL(sqrt, double (*)(double));
  REGISTER_LIBM_SYMBOL(tan, double (*)(double));
  REGISTER_LIBM_SYMBOL(tanh, double (*)(double));
  REGISTER_LIBM_SYMBOL(tgamma, double (*)(double));
  REGISTER_LIBM_SYMBOL(trunc, double (*)(double));

#undef REGISTER_LIBM_SYMBOL

  registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
  registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
  registry->Register("memset", reinterpret_cast<void*>(memset), "Host");

  // Used by MLIR lowering.
  registry->Register("malloc", reinterpret_cast<void*>(malloc), "Host");
  registry->Register("calloc", reinterpret_cast<void*>(calloc), "Host");
  registry->Register("free", reinterpret_cast<void*>(free), "Host");
#ifndef _WIN32
  // TODO(kramerb): This fails to link on windows because it's marked dllimport.
  registry->Register("memrefCopy", reinterpret_cast<void*>(memrefCopy), "Host");
#endif

  // Apple-only symbols. Note that both "__bzero" and "bzero" map to the same
  // bzero function.
#ifdef __APPLE__
  registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
  registry->Register("bzero", reinterpret_cast<void*>(bzero), "Host");
  registry->Register("memset_pattern16",
                     reinterpret_cast<void*>(memset_pattern16), "Host");
#endif

  // Only registered in MemorySanitizer builds.
#ifdef MEMORY_SANITIZER
  registry->Register("__msan_unpoison",
                     reinterpret_cast<void*>(__msan_unpoison), "Host");
#endif

  // Only registered when PLATFORM_WINDOWS is defined.
#if defined(PLATFORM_WINDOWS)
  registry->Register("__chkstk", reinterpret_cast<void*>(__chkstk), "Host");
#endif

  return true;
}
450
// Namespace-scope variable whose initializer runs RegisterKnownJITSymbols()
// during static initialization of this translation unit. Only the side effect
// of the call matters; the stored value is not read anywhere in this file.
bool unused = RegisterKnownJITSymbols();
452 } // namespace
453
454 } // namespace cpu
455 } // namespace xla
456