/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace gpu {

// Prints a warning if the PTX->SASS JIT in the CUDA driver has known bugs.
// This matters when we fail to compile the PTX ourselves and fall back to
// the driver's JIT.
void WarnIfBadDriverJITVersion();

// NVPTXCompiler generates efficient GPU executables for the NVPTX target.
class NVPTXCompiler : public GpuCompiler {
 public:
  NVPTXCompiler();
  ~NVPTXCompiler() override {}

  Status OptimizeHloConvolutionCanonicalization(
      HloModule* hlo_module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) override;

  Status OptimizeHloPostLayoutAssignment(
      HloModule* hlo_module, se::StreamExecutor* stream_exec,
      se::DeviceMemoryAllocator* device_allocator) override;

  HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() override;

  GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;

  StatusOr<std::pair<std::string, std::vector<uint8_t>>> CompileTargetBinary(
      const HloModuleConfig& module_config, llvm::Module* llvm_module,
      GpuVersion gpu_version, se::StreamExecutor* stream_exec, bool relocatable,
      const HloModule* debug_module) override;

 private:
  StatusOr<std::vector<uint8_t>> LinkModules(
      se::StreamExecutor* stream_exec,
      std::vector<std::vector<uint8_t>> modules) override;

  absl::Mutex mutex_;

  // When compiling an HLO module, we need to find a path to the nvvm libdevice
  // files.  We search in the module's config.debug_options().cuda_data_dir()
  // and in tensorflow::LibdeviceRoot(), the latter of which is a constant.
  //
  // We cache the cuda_data_dir() and the result of our search, so that if the
  // next module we have to compile has the same cuda_data_dir(), we can skip
  // the search.
  std::string cached_cuda_data_dir_ ABSL_GUARDED_BY(mutex_);
  std::string cached_libdevice_dir_ ABSL_GUARDED_BY(mutex_);

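  // A minimal sketch of the lookup these two cached fields enable (the helper
  // name and the search routine are hypothetical; the real logic lives in
  // nvptx_compiler.cc):
  //
  //   std::string GetLibdeviceDir(const HloModuleConfig& config) {
  //     absl::MutexLock lock(&mutex_);
  //     const std::string& data_dir =
  //         config.debug_options().xla_gpu_cuda_data_dir();
  //     if (data_dir != cached_cuda_data_dir_) {
  //       cached_cuda_data_dir_ = data_dir;
  //       cached_libdevice_dir_ = SearchForLibdevice(data_dir);  // hypothetical
  //     }
  //     return cached_libdevice_dir_;
  //   }
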
  // Tries to compile the given ptx string to cubin.  Returns a vector with the
  // compiled cubin.  If compilation was unsuccessful, returns an empty vector.
  std::vector<uint8_t> CompileGpuAsmOrGetCachedResult(
      se::StreamExecutor* stream_exec, const std::string& ptx,
      se::CudaComputeCapability cc, const HloModuleConfig& hlo_module_config,
      bool relocatable);

  // The compilation_cache_ map is a cache from
  // {ptx string, cc_major, cc_minor, relocatable} -> cubin, so we don't
  // recompile the same ptx twice.  This is important for some interactive
  // workflows.  (We also cache at the HLO level, but sometimes we can't
  // realize that two modules are the same until we lower to ptx.)
  //
  // Compilation of distinct PTX happens in parallel.  If more than one thread
  // attempts to compile the same PTX, the first thread to obtain
  // cache_value->mutex performs the compilation.  The rest Wait() on
  // cache_value->compilation_done_cv until the compilation is done.
  //
  // If compiling the ptx fails, we return an empty cubin, cross our fingers,
  // and leave compilation up to the driver.
  struct CompilationCacheKey {
    CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
                        bool relocatable)
        : ptx(std::move(ptx)),
          cc_major(cc_major),
          cc_minor(cc_minor),
          relocatable(relocatable) {}
    template <typename H>
    friend H AbslHashValue(H h, const CompilationCacheKey& key) {
      return H::combine(std::move(h), key.ptx, key.cc_major, key.cc_minor,
                        key.relocatable);
    }
    friend bool operator==(const CompilationCacheKey& a,
                           const CompilationCacheKey& b) {
      return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
             a.ptx == b.ptx && a.relocatable == b.relocatable;
    }
    std::string ptx;
    int cc_major;
    int cc_minor;
    bool relocatable;
  };
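
  // The AbslHashValue friend above hooks CompilationCacheKey into Abseil's
  // hashing framework, which is what lets it serve as the key type of the
  // absl::node_hash_map below.  A usage sketch (the values are illustrative):
  //
  //   CompilationCacheKey key(ptx_string, /*cc_major=*/8, /*cc_minor=*/0,
  //                           /*relocatable=*/false);
  //   size_t h = absl::Hash<CompilationCacheKey>{}(key);  // calls AbslHashValue
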
  struct CompilationCacheValue {
    bool compilation_done = false;
    std::vector<uint8_t> cubin_data;
    // Mutex and condition variable used to signal that compilation of this
    // entry has completed.
    absl::Mutex mutex;
    absl::CondVar compilation_done_cv;
  };

  // Don't even think about switching this to flat_hash_map; iterator stability
  // is critical here.
  absl::node_hash_map<CompilationCacheKey, CompilationCacheValue>
      compilation_cache_ ABSL_GUARDED_BY(mutex_);
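
  // A sketch of the compile-or-wait protocol described above (assumed shape,
  // not the exact code in nvptx_compiler.cc).  The thread that inserts a new
  // cache entry compiles; later arrivals wait on that entry's condition
  // variable:
  //
  //   CompilationCacheValue* value;
  //   bool inserted;
  //   {
  //     absl::MutexLock map_lock(&mutex_);  // mutex_ guards only the map
  //     auto [it, ins] = compilation_cache_.try_emplace(key);
  //     value = &it->second;  // stays valid: node_hash_map never moves nodes
  //     inserted = ins;
  //   }
  //   absl::MutexLock entry_lock(&value->mutex);
  //   if (inserted) {
  //     value->cubin_data = /* run ptxas on key.ptx */;
  //     value->compilation_done = true;
  //     value->compilation_done_cv.SignalAll();
  //   } else {
  //     while (!value->compilation_done) {
  //       value->compilation_done_cv.Wait(&value->mutex);
  //     }
  //   }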

  NVPTXCompiler(const NVPTXCompiler&) = delete;
  NVPTXCompiler& operator=(const NVPTXCompiler&) = delete;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_