xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"

#include <stdlib.h>

#include <fstream>
#include <string>
#include <utility>

#include "absl/base/call_once.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_pad_for_gemms.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_simplify_padding.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/metrics.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_helper.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/gpu/triangular_solve_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"

namespace xla {
namespace gpu {

Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
  // Convert convolutions into CustomCalls to cudnn, then canonicalize them
  // (GpuConvPaddingLegalization). Also expand cuSolver calls.
  HloPassPipeline pipeline("conv_canonicalization");
  pipeline.AddInvariantCheckerDebug<HloVerifier>(
      /*layout_sensitive=*/false,
      /*allow_mixed_precision=*/false);
  pipeline.AddPass<GpusolverRewriter>();
  pipeline.AddPass<GpuConvRewriter>();
  pipeline.AddPass<CudnnFusedConvRewriter>();
  pipeline.AddPass<GpuConvPaddingLegalization>();
  pipeline.AddPass<CudnnPadForConvolutions>(
      stream_exec->GetDeviceDescription().cuda_compute_capability());
  pipeline.AddPass<CudnnVectorizeConvolutions>(
      stream_exec->GetDeviceDescription().cuda_compute_capability());
  // The conv padding/vectorization passes above introduce calls that we need
  // to get rid of.  They also leave behind unnecessary tuple/get-tuple-element
  // pairs that TupleSimplifier fixes.
  pipeline.AddPass<CallInliner>();
  pipeline.AddPass<TupleSimplifier>();

  AlgebraicSimplifierOptions algsimp_options;
  algsimp_options.set_enable_conv_operand_swap(false);
  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options);

  // CudnnSimplifyPadding gets rid of some padding introduced by
  // CudnnPadForConvolutions and used by CudnnVectorizeConvolutions.  The
  // pattern-matches in this pass need to be run after inlining and simplifying
  // tuples from CudnnVectorizeConvolutions.  We also need to run algsimp to
  // e.g. clean up unnecessary nop `convert`s.
  pipeline.AddPass<CudnnSimplifyPadding>();

  // The tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
  // CudnnSimplifyPadding introduce reshapes and transposes that can be
  // eliminated using AlgebraicSimplifier.  We run algsimp to a fixed point.
  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options);

  // GpuConvRewriter, GpuConvPaddingLegalization and
  // CudnnPadForConvolutions may add instructions which can be simplified
  // by constant folding.
  pipeline.AddPass<HloConstantFolding>();
  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());

  return OkStatus();
}

Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
  HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1");

  // This needs to run before GemmRewriter, which is part of
  // OptimizeHloPostLayoutAssignment().
  if (stream_exec->GetDeviceDescription().cuda_compute_capability().IsAtLeast(
          se::CudaComputeCapability::AMPERE)) {
    pre_pipeline.AddPass<CublasPadForGemms>(PrimitiveType::BF16,
                                            /*pad_to_multiple_of=*/8);
  }
  if (stream_exec->GetDeviceDescription().cuda_compute_capability().IsAtLeast(
          se::CudaComputeCapability::VOLTA)) {
    // Pad gemms over S8 to multiples of 4 so cuBLAS can run them.
    pre_pipeline.AddPass<CublasPadForGemms>(PrimitiveType::S8,
                                            /*pad_to_multiple_of=*/4);

    // Pad the dimensions of matrices in dot operations to multiples of 8.
    pre_pipeline.AddPass<CublasPadForGemms>(PrimitiveType::F16,
                                            /*pad_to_multiple_of=*/8);
  }
  // Padding a gemm operand that's a constant results in pad(constant).  Run
  // constant-folding to simplify this into a new constant.
  pre_pipeline.AddPass<HloConstantFolding>();
  TF_RETURN_IF_ERROR(pre_pipeline.Run(hlo_module).status());

  TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment(
      hlo_module, stream_exec, device_allocator));

  HloPassPipeline post_pipeline("nvptx post-layout_assignment part 2");

  // Find the fastest algorithm for GEMMs. Skip on Ampere and later as the
  // algorithm goes unused.
  if (!stream_exec->GetDeviceDescription().cuda_compute_capability().IsAtLeast(
          se::CudaComputeCapability::AMPERE)) {
    post_pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);
  }

  // Transform TriangularSolve ops into custom-calls, so we can add temp
  // memory.
  post_pipeline.AddPass<TriangularSolveRewriter>();

  TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());

  return OkStatus();
}

namespace {
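// Buffer-sharing hint for HLO dataflow analysis: returns true or false when
// this backend can decide whether `user` may reuse `operand`'s buffer at
// `user_index`, and std::nullopt to defer to the default sharing rules.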
std::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                       const HloInstruction* operand,
                                       const ShapeIndex& user_index) {
  switch (user->opcode()) {
    case HloOpcode::kAllReduce:
      // NCCL all-reduce can be performed in-place.
      return user->operand_count() == 1 ||
             (user_index.size() == 1 &&
              user->operand(user_index[0]) == operand);
    case HloOpcode::kCustomCall:
      // The matrix bias operand can be overwritten in-place.
      if (user->custom_call_target() == kCublasLtMatmulCallTarget) {
        GemmBackendConfig config =
            std::move(user->backend_config<GemmBackendConfig>()).ValueOrDie();
        return (config.beta() != 0.) && user->operand(2) == operand;
      }
      // The operand of cholesky can be shared with the first output.
      if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
        return user_index.size() == 1 && user_index[0] == 0;
      }
      return false;
    default:
      return std::nullopt;
  }
}

// Tries to load PTX from the files listed in the xla_gpu_ptx_file debug
// option.  Returns true on success.
bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
                          const HloModule* module, std::string* ptx) {
  // If the xla_gpu_ptx_file option is set, be explicit about whether a file is
  // used, and warn when no file is used, to ease catching typos in filenames.
  std::string prefix = xla::FilenameFor(*module, "", *ptx);
  std::string matched_filename;
  for (const std::string& full_filename :
       module_config.debug_options().xla_gpu_ptx_file()) {
    // To ease comparing many PTX versions, accept suffixes different from
    // the original filename.
    auto filename = tensorflow::io::Basename(full_filename);
    if (absl::StartsWith(filename, prefix)) {
      matched_filename = full_filename;
      VLOG(1) << "RunBackend() - Will load PTX from file: " << full_filename;
      break;
    }
  }
  if (!module_config.debug_options().xla_gpu_ptx_file().empty() &&
      matched_filename.empty()) {
    VLOG(1) << "RunBackend() - For module with prefix '" << prefix
            << "', we did not find a PTX file to load.";
  }

  if (!matched_filename.empty()) {
    std::ifstream ifs(matched_filename, std::ifstream::in);
    *ptx = std::string(std::istreambuf_iterator<char>(ifs),
                       std::istreambuf_iterator<char>());
    CHECK(!ptx->empty()) << "Empty or non-existent PTX file: "
                         << matched_filename;
    return true;
  }
  return false;
}

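// Illustrative usage sketch for MaybeLoadPtxFromFile above (hypothetical path,
// not part of this file): setting the debug option, e.g. via the XLA_FLAGS
// environment variable,
//   XLA_FLAGS=--xla_gpu_ptx_file=/tmp/module_0001.my_cluster.ptx
// causes the file whose basename starts with the module's dump prefix (as
// produced by xla::FilenameFor) to be matched, and its contents are then used
// as the PTX for that module instead of the freshly compiled PTX.
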
// Tries to load textual LLVM IR from the files listed in the
// xla_gpu_llvm_ir_file debug option.  If successful, returns the llvm::Module;
// otherwise returns nullptr.
std::unique_ptr<llvm::Module> MaybeLoadLLVMFromFile(const HloModule* module,
                                                    llvm::Module* llvm_module) {
  // If the xla_gpu_llvm_ir_file option is set, be explicit about whether a
  // file is used, and warn when no file is used, to ease catching typos in
  // filenames.
  if (module == nullptr) {
    return nullptr;
  }

  std::string prefix = xla::FilenameFor(*module, "", "");
  auto xla_gpu_llvm_ir_file =
      module->config().debug_options().xla_gpu_llvm_ir_file();
  auto matched_filename = absl::c_find_if(
      xla_gpu_llvm_ir_file, [prefix](const std::string& full_filename) {
        // To ease comparing many LLVM versions, accept suffixes different from
        // the original filename.
        return absl::StartsWith(tensorflow::io::Basename(full_filename),
                                prefix);
      });
  if (!xla_gpu_llvm_ir_file.empty() &&
      matched_filename == std::end(xla_gpu_llvm_ir_file)) {
    VLOG(1) << "RunBackend() - For module with prefix '" << prefix
            << "', we did not find an LLVM file to load.";
  }

  if (matched_filename != std::end(xla_gpu_llvm_ir_file)) {
    VLOG(1) << "RunBackend() - Will load LLVM from file: " << *matched_filename;
    llvm::LLVMContext& context = llvm_module->getContext();
    llvm::SMDiagnostic err;
    std::unique_ptr<llvm::Module> loaded_module =
        llvm::parseIRFile(*matched_filename, err, context);

    if (!loaded_module) {
      err.print("ERR", llvm::errs());
      LOG(FATAL) << "Failed to load an LLVM file. It is probably invalid LLVM.";
    }
    // Overwrite the dumped unoptimized LLVM IR so the dump shows which module
    // will actually be used.
    llvm_ir::DumpIrIfEnabled(*module, *loaded_module, /*optimized=*/false);
    return loaded_module;
  }
  return nullptr;
}

}  // namespace

// Prints a warning if the ptx->sass JIT in the driver has known bugs.
//
// Using such a driver is only a problem if we fail to use ptxas to compile our
// ptx and have to use the driver instead, so you should only call this
// function if we're going to use the driver JIT.
//
// Only prints a warning the first time it's called.
void WarnIfBadDriverJITVersion() {
  static absl::once_flag run_once;
  absl::call_once(run_once, [] {
    auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion();
    if (!version_or_status.ok()) {
      LOG(WARNING) << "Couldn't read CUDA driver version.";
      return;
    }
    se::cuda::DriverVersion version = version_or_status.ValueOrDie();

    // The following versions of the driver JIT miscompile some address
    // calculations with large offsets (e.g. "load ptr + large_constant"),
    // b/70245379:
    //
    //  - 384.x before 384.108
    //  - 387.x before 387.40
    //  - 390.x before 390.10.
    //
    // In addition, only >= 396.20 contains ptxas >= 9.2.88, which contains the
    // fix for the "large multioutput fusions" miscompile, b/111107644.
    if (version < std::make_tuple(396, 20, 0)) {
      LOG(WARNING)
          << "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
          << se::cuda::DriverVersionToString(version)
          << ", which is older than 396.20.0. These versions are known to "
             "miscompile XLA code, leading to incorrect results or "
             "invalid-address errors.\nXLA only uses the driver JIT if it "
             "cannot find ptxas; you don't need to update your driver if "
             "you can point XLA to ptxas 9.2.88 or newer.";
    }
  });
}

NVPTXCompiler::NVPTXCompiler()
    : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::TargetTriple(),
                  nvptx::DataLayout()) {}

HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() {
  return &CanShareBufferHint;
}

GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
  return stream_exec->GetDeviceDescription().cuda_compute_capability();
}

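// Lowers `llvm_module` to PTX (or, if the corresponding debug options are set,
// loads PTX or LLVM IR from files instead) and then compiles the PTX to cubin,
// returning the (ptx, cubin) pair.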
StatusOr<std::pair<std::string, std::vector<uint8_t>>>
NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
                                   llvm::Module* llvm_module,
                                   GpuVersion gpu_version,
                                   se::StreamExecutor* stream_exec,
                                   bool relocatable,
                                   const HloModule* debug_module) {
  std::string libdevice_dir;
  {
    absl::MutexLock lock(&mutex_);

    // Find the directory containing libdevice.  To avoid searching for it
    // every time, we have a one-element cache, keyed on the module's config's
    // cuda_data_dir.
    if (cached_libdevice_dir_.empty()) {
      cached_libdevice_dir_ = GetLibdeviceDir(module_config);
    }
    libdevice_dir = cached_libdevice_dir_;
  }
  VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n";
  std::unique_ptr<llvm::Module> loaded_module =
      MaybeLoadLLVMFromFile(debug_module, llvm_module);
  llvm::Module* selected_module = nullptr;
  if (loaded_module) {
    selected_module = loaded_module.get();
  } else {
    selected_module = llvm_module;
  }

  std::string ptx;
  if (!(debug_module &&
        MaybeLoadPtxFromFile(module_config, debug_module, &ptx))) {
    XLA_SCOPED_LOGGING_TIMER(
        "NVPTXCompiler::CompileTargetBinary - CompileToPtx");
    uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();
    TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(selected_module, gpu_version,
                                                 module_config, libdevice_dir));

    uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
    // This won't record values for calls that error out (because if they error
    // out we have no way of telling how far through the process we got).
    RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs);
  }

  std::vector<uint8_t> cubin = CompileGpuAsmOrGetCachedResult(
      stream_exec, ptx, std::get<se::CudaComputeCapability>(gpu_version),
      module_config, relocatable);

  return std::pair<std::string, std::vector<uint8_t>>(std::move(ptx),
                                                      std::move(cubin));
}

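// Compiles `ptx` to cubin with ptxas, caching results keyed on the PTX string,
// the compute capability, and the `relocatable` flag.  Threads that request a
// key which is already being compiled block until that compilation finishes
// and then share its result.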
std::vector<uint8_t> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
    se::StreamExecutor* stream_exec, const std::string& ptx,
    se::CudaComputeCapability cc, const HloModuleConfig& hlo_module_config,
    bool relocatable) {
  XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
  tensorflow::profiler::TraceMe activity(
      "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
  bool inserted;
  decltype(compilation_cache_.begin()) iter;
  // Pointers into compilation_cache_ where the ptx and (optional) cubin are
  // stored.
  const std::string* cache_ptx = nullptr;
  CompilationCacheValue* cache_value = nullptr;

  {
    absl::MutexLock lock(&mutex_);
    std::tie(iter, inserted) = compilation_cache_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable),
        std::forward_as_tuple());
    cache_ptx = &iter->first.ptx;
    cache_value = &iter->second;
  }

  // Compile the ptx if it wasn't in the cache before we called this function.
  // Other threads asking for the same compilation key will block on
  // cache_value->mutex_ until compilation is done.
  {
    absl::MutexLock lock(&cache_value->mutex);
    if (inserted) {
      CHECK(!cache_value->compilation_done);
      if (!ptx.empty()) {
        auto ptxas_config =
            PtxOptsFromDebugOptions(hlo_module_config.debug_options());
        if (relocatable) {
          ptxas_config.extra_flags.push_back("-c");
        }
        uint64_t start_usecs = tensorflow::Env::Default()->NowMicros();

        StatusOr<std::vector<uint8_t>> maybe_cubin = se::CompileGpuAsm(
            stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);

        if (maybe_cubin.ok()) {
          uint64_t end_usecs = tensorflow::Env::Default()->NowMicros();
          // This won't record values for calls that error out (because if they
          // error out we have no way of telling how far through the process we
          // got).
          RecordPtxToCubinDuration(end_usecs - start_usecs);
          cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
          VLOG(1) << "Compiled PTX size: " << ptx.size()
                  << " CUBIN size: " << cache_value->cubin_data.size();
        } else {
          if (maybe_cubin.status().code() ==
              tensorflow::error::Code::NOT_FOUND) {
            if (!hlo_module_config.debug_options()
                     .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) {
              LOG(WARNING) << CantFindCudaMessage(
                  "Can't find ptxas binary in ${CUDA_DIR}/bin.  Custom ptxas "
                  "location can be specified using $PATH.",
                  hlo_module_config);
              LOG(FATAL)
                  << "Can't find ptxas binary.  You can pass the flag "
                     "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found "
                     "to use the GPU driver for compiling ptx instead. However "
                     "this option is discouraged and can lead to increased "
                     "memory consumption and other subtle runtime issues.";
            }
            // Missing ptxas is expected in some environments where CUDA SDK
            // binaries are not available. We don't want to spam logs with
            // identical warnings in this case.
            LOG_FIRST_N(WARNING, 1) << CantFindCudaMessage(
                "Can't find ptxas binary in ${CUDA_DIR}/bin.  Will fall back "
                "to the GPU driver for PTX -> sass compilation.  This is OK so "
                "long as you don't see a warning below about an out-of-date "
                "driver version. Custom ptxas location can be specified "
                "using $PATH.",
                hlo_module_config);
          } else if (maybe_cubin.status().code() !=
                     tensorflow::error::Code::UNIMPLEMENTED) {
            // If UNIMPLEMENTED is returned, we fall back to the driver.
            LOG(FATAL) << "ptxas returned an error during compilation of ptx "
                          "to sass: '"
                       << maybe_cubin.status() << "'  "
                       << "If the error message indicates that a file could "
                          "not be written, please verify that sufficient "
                          "filesystem space is provided.";
          }

          // We're going to use the driver to JIT our PTX->SASS, so warn if
          // the JIT in the driver has known bugs.
          WarnIfBadDriverJITVersion();
        }
      }
      cache_value->compilation_done = true;
      cache_value->compilation_done_cv.SignalAll();
    } else {
      while (!cache_value->compilation_done) {
        cache_value->compilation_done_cv.Wait(&cache_value->mutex);
      }
    }
  }

  CHECK(cache_value != nullptr);
  CHECK(cache_value->compilation_done);
  return cache_value->cubin_data;
}

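// Links the given relocatable GPU modules into a single image using LinkGpuAsm
// on the stream executor's GPU context.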
StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
    se::StreamExecutor* stream_exec,
    std::vector<std::vector<uint8_t>> modules) {
  std::vector<stream_executor::CubinOrPTXImage> images;
  images.reserve(modules.size());
  for (auto& module : modules) {
    images.push_back({"", std::move(module)});
  }
  return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
                        stream_exec->implementation()->GpuContextHack()),
                    images);
}

}  // namespace gpu
}  // namespace xla