/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"

#include <stdlib.h>

#include <fstream>
#include <string>
#include <utility>

#include "absl/base/call_once.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_pad_for_gemms.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_simplify_padding.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/metrics.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_helper.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/gpu/triangular_solve_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"

namespace xla {
namespace gpu {

Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
  // Convert convolutions into CustomCalls to cudnn, then canonicalize them
  // (GpuConvPaddingLegalization). Also expand cuSolver calls.
  HloPassPipeline pipeline("conv_canonicalization");
  pipeline.AddInvariantCheckerDebug<HloVerifier>(
      /*layout_sensitive=*/false,
      /*allow_mixed_precision=*/false);
  pipeline.AddPass<GpusolverRewriter>();
  pipeline.AddPass<GpuConvRewriter>();
  pipeline.AddPass<CudnnFusedConvRewriter>();
  pipeline.AddPass<GpuConvPaddingLegalization>();
  pipeline.AddPass<CudnnPadForConvolutions>(
      stream_exec->GetDeviceDescription().cuda_compute_capability());
  pipeline.AddPass<CudnnVectorizeConvolutions>(
      stream_exec->GetDeviceDescription().cuda_compute_capability());
  // The conv padding/vectorization passes above leave behind calls and
  // unnecessary tuple/get-tuple-element pairs; CallInliner and TupleSimplifier
  // clean these up.
  pipeline.AddPass<CallInliner>();
  pipeline.AddPass<TupleSimplifier>();

  AlgebraicSimplifierOptions algsimp_options;
  algsimp_options.set_enable_conv_operand_swap(false);
  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options);

  // CudnnSimplifyPadding gets rid of some padding introduced by
  // CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. Its
  // pattern matches need to run after inlining and simplifying the tuples
  // produced by CudnnVectorizeConvolutions. We also need to run algsimp to
  // e.g. clean up unnecessary nop `convert`s.
  pipeline.AddPass<CudnnSimplifyPadding>();

  // The tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
  // CudnnSimplifyPadding introduce reshapes and transposes that can be
  // eliminated using AlgebraicSimplifier. We run algsimp to a fixed point.
  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options);

  // GpuConvRewriter, GpuConvPaddingLegalization, and CudnnPadForConvolutions
  // may add instructions which can be simplified by constant folding.
  pipeline.AddPass<HloConstantFolding>();
  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());

  return OkStatus();
}

Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
  HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1");

  // This needs to run before GemmRewriter, which is part of
  // OptimizeHloPostLayoutAssignment().
  if (stream_exec->GetDeviceDescription().cuda_compute_capability().IsAtLeast(
          se::CudaComputeCapability::AMPERE)) {
    pre_pipeline.AddPass<CublasPadForGemms>(PrimitiveType::BF16,
                                            /*pad_to_multiple_of=*/8);
  }
  if (stream_exec->GetDeviceDescription().cuda_compute_capability().IsAtLeast(
          se::CudaComputeCapability::VOLTA)) {
    // Pad gemms over S8 to multiples of 4 so cuBLAS can run them.
    pre_pipeline.AddPass<CublasPadForGemms>(PrimitiveType::S8,
                                            /*pad_to_multiple_of=*/4);

    // Pad the dimensions of matrices in dot operations to multiples of 8.
    pre_pipeline.AddPass<CublasPadForGemms>(PrimitiveType::F16,
                                            /*pad_to_multiple_of=*/8);
  }
  // Padding a gemm operand that's a constant results in pad(constant). Run
  // constant-folding to simplify this into a new constant.
  pre_pipeline.AddPass<HloConstantFolding>();
  TF_RETURN_IF_ERROR(pre_pipeline.Run(hlo_module).status());

  TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment(
      hlo_module, stream_exec, device_allocator));

  HloPassPipeline post_pipeline("nvptx post-layout_assignment part 2");

  // Find the fastest algorithm for GEMMs. Skip on Ampere and later as the
  // algorithm goes unused.
  if (!stream_exec->GetDeviceDescription().cuda_compute_capability().IsAtLeast(
          se::CudaComputeCapability::AMPERE)) {
    post_pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);
  }

  // Transform TriangularSolve ops into custom-calls, so we can add temp
  // memory.
  post_pipeline.AddPass<TriangularSolveRewriter>();

  TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());

  return OkStatus();
}

namespace {
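// Returns true/false when the NVPTX backend knows whether `user` may reuse
// the buffer of `operand` at `user_index`, and std::nullopt to defer to the
// default HloDataflowAnalysis logic.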
std::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                       const HloInstruction* operand,
                                       const ShapeIndex& user_index) {
  switch (user->opcode()) {
    case HloOpcode::kAllReduce:
      // NCCL all-reduce can be performed in-place.
      return user->operand_count() == 1 ||
             (user_index.size() == 1 &&
              user->operand(user_index[0]) == operand);
    case HloOpcode::kCustomCall:
      // The matrix bias operand can be overwritten in-place.
      if (user->custom_call_target() == kCublasLtMatmulCallTarget) {
        GemmBackendConfig config =
            std::move(user->backend_config<GemmBackendConfig>()).ValueOrDie();
        return (config.beta() != 0.) && user->operand(2) == operand;
      }
      // The operand of cholesky can be shared with the first output.
      if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
        return user_index.size() == 1 && user_index[0] == 0;
      }
      return false;
    default:
      return std::nullopt;
  }
}

// Try to load PTX from files defined in the FLAGS. If successful, return true.
bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
                          const HloModule* module, std::string* ptx) {
  // If the xla_gpu_ptx_file option is set, log explicitly when a file is used
  // and warn when one is not, to make typos in the filename easier to catch.
  std::string prefix = xla::FilenameFor(*module, "", *ptx);
  std::string matched_filename;
  for (const std::string& full_filename :
       module_config.debug_options().xla_gpu_ptx_file()) {
    // To ease comparing many PTX versions, accept different suffixes than
    // the original filename.
    auto filename = tensorflow::io::Basename(full_filename);
    if (absl::StartsWith(filename, prefix)) {
      matched_filename = full_filename;
      VLOG(1) << "RunBackend() - Will load PTX from file: " << full_filename;
      break;
    }
  }
  if (!module_config.debug_options().xla_gpu_ptx_file().empty() &&
      matched_filename.empty()) {
    VLOG(1) << "RunBackend() - For module with prefix '" << prefix
            << "', we did not find a PTX file to load.";
  }

  if (!matched_filename.empty()) {
    std::ifstream ifs(matched_filename, std::ifstream::in);
    *ptx = std::string(std::istreambuf_iterator<char>(ifs),
                       std::istreambuf_iterator<char>());
    CHECK(!ptx->empty()) << "Empty or nonexistent PTX file: "
                         << matched_filename;
    return true;
  }
  return false;
}

// Try to load textual LLVM IR from files defined in the FLAGS. If
// successful, return the llvm::Module, otherwise return nullptr.
std::unique_ptr<llvm::Module> MaybeLoadLLVMFromFile(const HloModule* module,
                                                    llvm::Module* llvm_module) {
  // If the xla_gpu_llvm_ir_file option is set, log explicitly when a file is
  // used and warn when one is not, to make typos in the filename easier to
  // catch.
  if (module == nullptr) {
    return nullptr;
  }

  std::string prefix = xla::FilenameFor(*module, "", "");
  auto xla_gpu_llvm_ir_file =
      module->config().debug_options().xla_gpu_llvm_ir_file();
  auto matched_filename = absl::c_find_if(
      xla_gpu_llvm_ir_file, [prefix](const std::string& full_filename) {
        // To ease comparing many LLVM versions, accept different suffixes than
        // the original filename.
        return absl::StartsWith(tensorflow::io::Basename(full_filename),
                                prefix);
      });
  if (!xla_gpu_llvm_ir_file.empty() &&
      matched_filename == std::end(xla_gpu_llvm_ir_file)) {
    VLOG(1) << "RunBackend() - For module with prefix '" << prefix
            << "', we did not find an LLVM file to load.";
  }

  if (matched_filename != std::end(xla_gpu_llvm_ir_file)) {
    VLOG(1) << "RunBackend() - Will load LLVM from file: " << *matched_filename;
    llvm::LLVMContext& context = llvm_module->getContext();
    llvm::SMDiagnostic err;
    std::unique_ptr<llvm::Module> loaded_module =
        llvm::parseIRFile(*matched_filename, err, context);

    if (!loaded_module) {
      err.print("ERR", llvm::errs());
      LOG(FATAL) << "Failed to load an LLVM file. It is probably invalid LLVM.";
    }
    // Overwrite the dumped unoptimized LLVM IR so the dump shows which module
    // will actually be used.
    llvm_ir::DumpIrIfEnabled(*module, *loaded_module, /*optimized=*/false);
    return loaded_module;
  }
  return nullptr;
}

}  // namespace

// Prints a warning if the ptx->sass JIT in the driver has known bugs.
//
// Using such a driver is only a problem if we fail to use ptxas to compile
// our ptx and have to use the driver instead, so you should only call this
// function if we're going to use the driver JIT.
//
// Only prints a warning the first time it's called.
void WarnIfBadDriverJITVersion() {
  static absl::once_flag run_once;
  absl::call_once(run_once, [] {
    auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion();
    if (!version_or_status.ok()) {
      LOG(WARNING) << "Couldn't read CUDA driver version.";
      return;
    }
    se::cuda::DriverVersion version = version_or_status.ValueOrDie();

    // The following versions of the driver JIT miscompile some address
    // calculations with large offsets (e.g. "load ptr + large_constant"),
    // b/70245379:
    //
    //  - 384.x before 384.108
    //  - 387.x before 387.40
    //  - 390.x before 390.10.
    //
    // In addition, only >= 396.20 contains ptxas >= 9.2.88, which contains the
    // fix for the "large multioutput fusions" miscompile, b/111107644.
    if (version < std::make_tuple(396, 20, 0)) {
      LOG(WARNING)
          << "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
          << se::cuda::DriverVersionToString(version)
          << ", which is older than 396.20.0. These versions are known to "
             "miscompile XLA code, leading to incorrect results or "
             "invalid-address errors.\nXLA only uses the driver JIT if it "
             "cannot find ptxas; you don't need to update your driver if "
             "you can point XLA to ptxas 9.2.88 or newer.";
    }
  });
}

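// NVPTXCompiler emits code for CUDA devices: it configures the base
// GpuCompiler with the CUDA platform id and the NVPTX target triple and data
// layout.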
NVPTXCompiler::NVPTXCompiler()
    : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::TargetTriple(),
                  nvptx::DataLayout()) {}

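// Exposes CanShareBufferHint above as the backend's buffer-sharing predicate,
// consulted during buffer assignment.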
HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() {
  return &CanShareBufferHint;
}

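// On the CUDA platform, the GPU version is the device's compute capability.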
GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
  return stream_exec->GetDeviceDescription().cuda_compute_capability();
}

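// Lowers `llvm_module` to PTX and then to a cubin. For debugging, the
// xla_gpu_llvm_ir_file and xla_gpu_ptx_file options can substitute
// hand-provided LLVM IR or PTX for the generated code.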
StatusOr<std::pair<std::string, std::vector<uint8_t>>>
NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
                                   llvm::Module* llvm_module,
                                   GpuVersion gpu_version,
                                   se::StreamExecutor* stream_exec,
                                   bool relocatable,
                                   const HloModule* debug_module) {
  std::string libdevice_dir;
  {
    absl::MutexLock lock(&mutex_);

    // Find the directory containing libdevice. To avoid searching for it every
    // time, we have a one-element cache, keyed on the module's config's
    // cuda_data_dir.
    if (cached_libdevice_dir_.empty()) {
      cached_libdevice_dir_ = GetLibdeviceDir(module_config);
    }
    libdevice_dir = cached_libdevice_dir_;
  }
  VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n";
  std::unique_ptr<llvm::Module> loaded_module =
      MaybeLoadLLVMFromFile(debug_module, llvm_module);
  llvm::Module* selected_module = nullptr;
  if (loaded_module) {
    selected_module = loaded_module.get();
  } else {
    selected_module = llvm_module;
  }

  std::string ptx;
  if (!(debug_module &&
        MaybeLoadPtxFromFile(module_config, debug_module, &ptx))) {
    XLA_SCOPED_LOGGING_TIMER(
        "NVPTXCompiler::CompileTargetBinary - CompileToPtx");
    uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();
    TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(selected_module, gpu_version,
                                                 module_config, libdevice_dir));

    uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
    // This won't record values for calls that error out (because if they error
    // out we have no way of telling how far through the process we got).
    RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs);
  }

  std::vector<uint8_t> cubin = CompileGpuAsmOrGetCachedResult(
      stream_exec, ptx, std::get<se::CudaComputeCapability>(gpu_version),
      module_config, relocatable);

  return std::pair<std::string, std::vector<uint8_t>>(std::move(ptx),
                                                      std::move(cubin));
}

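// Compiles `ptx` to a cubin with ptxas, or returns a previously compiled
// result. Results are cached in compilation_cache_, keyed on
// (ptx, compute capability, relocatable); concurrent callers requesting the
// same key block until the first compilation finishes.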
std::vector<uint8_t> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
    se::StreamExecutor* stream_exec, const std::string& ptx,
    se::CudaComputeCapability cc, const HloModuleConfig& hlo_module_config,
    bool relocatable) {
  XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
  tensorflow::profiler::TraceMe activity(
      "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
  bool inserted;
  decltype(compilation_cache_.begin()) iter;
  // Pointers into compilation_cache_ where the ptx and (optional) cubin are
  // stored.
  const std::string* cache_ptx = nullptr;
  CompilationCacheValue* cache_value = nullptr;

  {
    absl::MutexLock lock(&mutex_);
    std::tie(iter, inserted) = compilation_cache_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable),
        std::forward_as_tuple());
    cache_ptx = &iter->first.ptx;
    cache_value = &iter->second;
  }

  // Compile the ptx if it wasn't in the cache before we called this function.
  // Other threads asking for the same compilation key will block on
  // cache_value->mutex_ until compilation is done.
  {
    absl::MutexLock lock(&cache_value->mutex);
    if (inserted) {
      CHECK(!cache_value->compilation_done);
      if (!ptx.empty()) {
        auto ptxas_config =
            PtxOptsFromDebugOptions(hlo_module_config.debug_options());
        if (relocatable) {
          ptxas_config.extra_flags.push_back("-c");
        }
        uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();

        StatusOr<std::vector<uint8_t>> maybe_cubin = se::CompileGpuAsm(
            stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);

        if (maybe_cubin.ok()) {
          uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
          // This won't record values for calls that error out (because if they
          // error out we have no way of telling how far through the process we
          // got).
          RecordPtxToCubinDuration(end_usecs - start_usecs);
          cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
          VLOG(1) << "Compiled PTX size: " << ptx.size()
                  << " CUBIN size: " << cache_value->cubin_data.size();
        } else {
          if (maybe_cubin.status().code() ==
              tensorflow::error::Code::NOT_FOUND) {
            if (!hlo_module_config.debug_options()
                     .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) {
              LOG(WARNING) << CantFindCudaMessage(
                  "Can't find ptxas binary in ${CUDA_DIR}/bin. A custom ptxas "
                  "location can be specified using $PATH.",
                  hlo_module_config);
              LOG(FATAL)
                  << "Can't find ptxas binary. You can pass the flag "
                     "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found "
                     "to use the GPU driver for compiling ptx instead. However "
                     "this option is discouraged and can lead to increased "
                     "memory consumption and other subtle runtime issues.";
            }
            // Missing ptxas is expected in some environments where CUDA SDK
            // binaries are not available. We don't want to spam logs with
            // identical warnings in this case.
            LOG_FIRST_N(WARNING, 1) << CantFindCudaMessage(
                "Can't find ptxas binary in ${CUDA_DIR}/bin. Will fall back to "
                "the GPU driver for PTX -> sass compilation. This is OK so "
                "long as you don't see a warning below about an out-of-date "
                "driver version. A custom ptxas location can be specified "
                "using $PATH.",
                hlo_module_config);
          } else if (maybe_cubin.status().code() !=
                     tensorflow::error::Code::UNIMPLEMENTED) {
            // If UNIMPLEMENTED is returned, we fall back to the driver.
            LOG(FATAL) << "ptxas returned an error during compilation of ptx "
                          "to sass: '"
                       << maybe_cubin.status() << "' "
                       << "If the error message indicates that a file could "
                          "not be written, please verify that sufficient "
                          "filesystem space is provided.";
          }

          // We're going to use the driver to JIT our PTX->SASS, so warn if
          // the JIT in the driver has known bugs.
          WarnIfBadDriverJITVersion();
        }
      }
      cache_value->compilation_done = true;
      cache_value->compilation_done_cv.SignalAll();
    } else {
      while (!cache_value->compilation_done) {
        cache_value->compilation_done_cv.Wait(&cache_value->mutex);
      }
    }
  }

  CHECK(cache_value != nullptr);
  CHECK(cache_value->compilation_done);
  return cache_value->cubin_data;
}

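// Links the given (relocatable) cubins into a single GPU binary. Each module
// is wrapped as a CubinOrPTXImage and handed to LinkGpuAsm in the executor's
// GPU context.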
StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
    se::StreamExecutor* stream_exec,
    std::vector<std::vector<uint8_t>> modules) {
  std::vector<stream_executor::CubinOrPTXImage> images;
  images.reserve(modules.size());
  for (auto& module : modules) {
    images.push_back({"", std::move(module)});
  }
  return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
                        stream_exec->implementation()->GpuContextHack()),
                    images);
}

}  // namespace gpu
}  // namespace xla