/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"

#include <functional>
#include <limits>
#include <string>
#include <tuple>
#include <utility>

#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"

namespace xla {
namespace gpu {

using tensorflow::AutotuneResult;

namespace {
struct AutotuneConfig {
  bool should_init_buffers() const { return autotune_level >= 2; }
  bool should_reinit_output_buffer() const { return autotune_level >= 3; }
  bool should_check_correctness() const { return autotune_level >= 4; }

  int32_t autotune_level;
  bool should_crash_on_check_failure;
};

AutotuneConfig GetConfig(const DebugOptions& debug_options) {
  return {debug_options.xla_gpu_autotune_level(),
          debug_options.xla_gpu_crash_on_verification_failures()};
}

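// Creates an allocator whose allocations are padded with redzones, so that
// out-of-bounds writes by a candidate algorithm can be detected. The redzones
// are only paid for when correctness checking is enabled.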
se::RedzoneAllocator CreateRedzoneAllocator(
    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
    const DebugOptions& debug_options, const AutotuneConfig& config) {
  int64_t redzone_size = config.should_check_correctness()
                             ? se::RedzoneAllocator::kDefaultRedzoneSize
                             : 0;

  return se::RedzoneAllocator(
      stream, allocator, PtxOptsFromDebugOptions(debug_options),
      /*memory_limit=*/std::numeric_limits<int64_t>::max(),
      /*redzone_size=*/redzone_size);
}

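// Allocates a device buffer large enough to hold `op`'s result and, depending
// on the autotune level, fills it with pseudorandom data via `rng_state`.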
StatusOr<se::DeviceMemoryBase> CreateBuffer(se::RedzoneAllocator& allocator,
                                            const HloInstruction& op,
                                            const AutotuneConfig& config,
                                            int64_t& rng_state) {
  TF_ASSIGN_OR_RETURN(
      se::DeviceMemoryBase buffer,
      allocator.AllocateBytes(ShapeUtil::ByteSizeOf(op.shape())));
  if (config.should_init_buffers()) {
    InitializeBuffer(allocator.stream(), op.shape().element_type(), &rng_state,
                     buffer);
  }
  return buffer;
}

// Returns the index (into `algorithms`) of the fastest algorithm.
template <typename AlgoT>
StatusOr<std::optional<size_t>> GetBestAlgorithm(
    se::Stream* stream, se::RedzoneAllocator& allocator,
    const HloInstruction& gemm, const AutotuneConfig& autotune_config,
    se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
    se::DeviceMemoryBase output_buffer, absl::Span<const AlgoT> algorithms,
    const std::function<StatusOr<se::blas::ProfileResult>(const AlgoT&)>&
        run_benchmark) {
  if (!stream->parent()->SynchronizeAllActivity()) {
    return InternalError("Failed to synchronize GPU for autotuning.");
  }

  TF_ASSIGN_OR_RETURN(GemmBackendConfig backend_config,
                      gemm.backend_config<GemmBackendConfig>());

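  // When correctness checking is on, the output of the first valid algorithm
  // becomes the reference that every later algorithm is compared against.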
  se::DeviceMemoryBase reference_buffer;
  if (autotune_config.should_check_correctness()) {
    TF_ASSIGN_OR_RETURN(
        reference_buffer,
        allocator.AllocateBytes(ShapeUtil::ByteSizeOf(gemm.shape())));
  }

  BufferComparator comparator(gemm.shape(), gemm.GetModule()->config());

  std::vector<AutotuneResult> results;
  std::optional<int64_t> reference_algorithm;

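  // Profile every candidate. Unsupported algorithms are recorded as
  // DISQUALIFIED rather than failing the whole pass.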
  for (const AlgoT& algorithm : algorithms) {
    // Make sure the output buffer always has the same value if we use
    // the bias parameter.
    if (autotune_config.should_reinit_output_buffer() &&
        backend_config.beta() != 0) {
      int64_t rng_state = 0;
      InitializeBuffer(stream, gemm.shape().element_type(), &rng_state,
                       output_buffer);
    }

    TF_ASSIGN_OR_RETURN(se::blas::ProfileResult profile_result,
                        run_benchmark(algorithm));

    results.emplace_back();
    AutotuneResult& result = results.back();
    result.mutable_gemm()->set_algorithm(profile_result.algorithm());

    if (!profile_result.is_valid()) {  // Unsupported algorithm.
      result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
      continue;
    }

    VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took "
            << profile_result.elapsed_time_in_ms() << "ms";

    *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
        absl::Milliseconds(profile_result.elapsed_time_in_ms()));

    if (!autotune_config.should_check_correctness()) {
      continue;
    }

    TF_ASSIGN_OR_RETURN(
        se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
        allocator.CheckRedzones());

    if (!rz_check_status.ok()) {
      result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
      *result.mutable_failure()->mutable_msg() =
          rz_check_status.RedzoneFailureMsg();
      LOG(ERROR) << "Detected out-of-bounds write in gemm buffer";
      CHECK(!autotune_config.should_crash_on_check_failure);
      continue;
    }

    if (!reference_algorithm) {
      stream->ThenMemcpy(&reference_buffer, output_buffer,
                         output_buffer.size());
      reference_algorithm = profile_result.algorithm();
    } else {
      // Perform the comparison.
      TF_ASSIGN_OR_RETURN(
          bool outputs_match,
          comparator.CompareEqual(stream, output_buffer, reference_buffer));
      if (!outputs_match) {
        LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
                   << "This is likely a bug/unexpected loss of precision.";
        CHECK(!autotune_config.should_crash_on_check_failure);

        result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT);
        result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
            *reference_algorithm);
      }
    }
  }

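  // Log the collected profiling results (skipped in crash-on-failure mode).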
  if (!autotune_config.should_crash_on_check_failure) {
    tensorflow::AutotuningLog log;
    for (const AutotuneResult& result : results) {
      *log.add_results() = result;
    }
    tensorflow::Logger::GetSingleton()->LogProto(log);
  }

  StatusOr<AutotuneResult> best = PickBestResult(results, gemm);
  if (best.ok()) {
    for (size_t i = 0; i < results.size(); ++i) {
      if (best->gemm().algorithm() == results[i].gemm().algorithm()) {
        return {i};
      }
    }
    return InternalError("unknown best algorithm");
  }

  LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance "
                  "might be suboptimal: "
               << best.status();
  return {std::nullopt};
}

StatusOr<std::optional<se::blas::AlgorithmType>> DoGemmAutotune(
    const HloInstruction* gemm, const GemmBackendConfig& gemm_config,
    se::DeviceMemoryAllocator* allocator, se::Stream* stream) {
  VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
  const HloInstruction* lhs = gemm->operand(0);
  const HloInstruction* rhs = gemm->operand(1);

  TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm));
  // Don't run autotuning concurrently on the same GPU.
  absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent()));

  auto key = std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
                             gemm->shape(), gemm_config.SerializeAsString(),
                             IsCublasLtMatmul(*gemm));

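  // Process-wide autotuning cache, keyed on the executor, operand/output
  // shapes, serialized backend config, and whether this is a cublasLt matmul.
  // A cached std::nullopt means autotuning ran but selected no algorithm.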
  static absl::Mutex mutex(absl::kConstInit);
  static auto& cache ABSL_GUARDED_BY(mutex) =
      *new absl::flat_hash_map<decltype(key),
                               std::optional<se::blas::AlgorithmType>>();
  static int64_t cache_hits ABSL_GUARDED_BY(mutex) = 0;
  static int64_t cache_misses ABSL_GUARDED_BY(mutex) = 0;

  absl::MutexLock lock(&mutex);
  auto it = cache.find(key);
  int64_t requests = cache_hits + cache_misses;
  if (requests && requests % 10 == 0) {
    VLOG(2) << "Autotuning cache hits/(hits + misses): " << cache_hits << "/"
            << requests;
  }

  if (it != cache.end()) {
    cache_hits++;
    VLOG(4) << "Autotuning cache hit, using algorithm: "
            << (it->second.has_value() ? absl::StrCat(*(it->second))
                                       : "<generic>");
    return it->second;
  }
  cache_misses++;
  VLOG(4) << "Autotuning cache miss";

  const DebugOptions& debug_options =
      gemm->GetModule()->config().debug_options();
  AutotuneConfig autotune_config = GetConfig(debug_options);

  se::RedzoneAllocator buffer_allocator =
      CreateRedzoneAllocator(stream, allocator, debug_options, autotune_config);

  int64_t rng_state = 0;
  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
                      CreateBuffer(buffer_allocator, *gemm->operand(0),
                                   autotune_config, rng_state));
  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
                      CreateBuffer(buffer_allocator, *gemm->operand(1),
                                   autotune_config, rng_state));
  TF_ASSIGN_OR_RETURN(
      se::DeviceMemoryBase output_buffer,
      CreateBuffer(buffer_allocator, *gemm, autotune_config, rng_state));

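  // cublasLt matmuls and legacy cuBLAS gemms expose different algorithm
  // enumerations, so each branch enumerates its own candidates and hands
  // GetBestAlgorithm a matching benchmark callback.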
  std::optional<se::blas::AlgorithmType> best_algorithm;
  if (IsCublasLtMatmul(*gemm)) {
    bool has_matrix_bias = config.beta != 0.;
    bool has_vector_bias = gemm_config.epilogue() == GemmBackendConfig::BIAS;

    auto epilogue = has_vector_bias ? se::cuda::BlasLt::Epilogue::kBias
                                    : se::cuda::BlasLt::Epilogue::kDefault;

    se::DeviceMemoryBase bias_buffer;
    if (has_vector_bias) {
      TF_ASSIGN_OR_RETURN(bias_buffer,
                          CreateBuffer(buffer_allocator,
                                       *gemm->operand(has_matrix_bias ? 3 : 2),
                                       autotune_config, rng_state));
    }

    TF_ASSIGN_OR_RETURN(auto plan,
                        cublas_lt::MatmulPlan::From(config, epilogue));
    TF_ASSIGN_OR_RETURN(
        std::vector<se::cuda::BlasLt::MatmulAlgorithm> algorithms,
        plan.GetAlgorithms(stream));

    TF_ASSIGN_OR_RETURN(
        std::optional<size_t> best_algorithm_idx,
        GetBestAlgorithm<se::cuda::BlasLt::MatmulAlgorithm>(
            stream, buffer_allocator, *gemm, autotune_config, lhs_buffer,
            rhs_buffer, output_buffer, algorithms,
            [&](const se::cuda::BlasLt::MatmulAlgorithm& algorithm)
                -> StatusOr<se::blas::ProfileResult> {
              se::OwningScratchAllocator<> scratch_allocator(
                  stream->parent()->device_ordinal(), allocator);
              se::blas::ProfileResult profile_result;
              TF_RETURN_IF_ERROR(plan.ExecuteOnStream(
                  stream, lhs_buffer, rhs_buffer, output_buffer, output_buffer,
                  bias_buffer, algorithm, scratch_allocator, &profile_result));
              return std::move(profile_result);
            }));

    TF_RET_CHECK(best_algorithm_idx) << "failed to auto-tune cublas_lt matmul";
    best_algorithm = *best_algorithm_idx;
  } else {
    std::vector<se::blas::AlgorithmType> algorithms;
    TF_RET_CHECK(stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms));

    TF_ASSIGN_OR_RETURN(std::optional<size_t> best_algorithm_idx,
                        GetBestAlgorithm<se::blas::AlgorithmType>(
                            stream, buffer_allocator, *gemm, autotune_config,
                            lhs_buffer, rhs_buffer, output_buffer, algorithms,
                            [&](const se::blas::AlgorithmType& algorithm)
                                -> StatusOr<se::blas::ProfileResult> {
                              se::blas::ProfileResult profile_result;
                              // We expect GemmWithAlgorithm to fail sometimes
                              // -- in fact, it will fail for all algorithms if
                              // we're targeting < sm_50. But because we pass a
                              // non-null ProfileResult, DoGemmWithAlgorithm
                              // should always return true, and the actual
                              // success-ness is returned in
                              // ProfileResult::is_valid.
                              TF_RETURN_IF_ERROR(RunGemm(
                                  config, lhs_buffer, rhs_buffer, output_buffer,
                                  stream, algorithm, &profile_result));
                              return std::move(profile_result);
                            }));

    if (best_algorithm_idx) best_algorithm = algorithms[*best_algorithm_idx];
  }

  CHECK(cache.emplace(key, best_algorithm).second);
  return best_algorithm;
}

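// Autotunes a single cuBLAS gemm / cublasLt matmul custom call and records
// the chosen algorithm in its backend config. Returns true iff the config
// changed.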
StatusOr<bool> RunOnInstruction(HloInstruction* instr,
                                se::StreamExecutor* executor,
                                se::DeviceMemoryAllocator* allocator) {
  if (allocator == nullptr) {
    allocator = executor->GetAllocator();
  }
  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                      allocator->GetStream(executor->device_ordinal()));

  GemmBackendConfig gemm_config =
      instr->backend_config<GemmBackendConfig>().ValueOrDie();

  TF_ASSIGN_OR_RETURN(std::optional<se::blas::AlgorithmType> gemm_algorithm,
                      DoGemmAutotune(instr, gemm_config, allocator, stream));

  // We update instruction->backend_config(); if no algorithms are supported,
  // a different API is used, which does not require specifying an algorithm.
  GemmBackendConfig updated_config = gemm_config;
  if (gemm_algorithm) {
    VLOG(4) << "GEMM autotuning picked algorithm " << *gemm_algorithm << " for "
            << instr->name();
    updated_config.set_selected_algorithm(*gemm_algorithm);
  }
  TF_RETURN_IF_ERROR(instr->set_backend_config(updated_config));
  return updated_config.SerializeAsString() != gemm_config.SerializeAsString();
}

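// Autotunes every cuBLAS/cublasLt custom call in `computation`. Returns true
// iff any instruction's backend config changed.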
StatusOr<bool> RunOnComputation(HloComputation* computation,
                                se::StreamExecutor* se,
                                se::DeviceMemoryAllocator* allocator) {
  bool changed = false;
  for (HloInstruction* instr : computation->instructions()) {
    if (IsCublasGemm(*instr) || IsCublasLtMatmul(*instr)) {
      TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, se, allocator));
      changed |= result;
    }
  }
  return changed;
}

}  // namespace

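// Pass entry point. Autotuning is skipped entirely when
// --xla_gpu_autotune_level is 0.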
StatusOr<bool> GemmAlgorithmPicker::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_SCOPED_LOGGING_TIMER("GemmAlgorithmPicker");

  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
    VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(
        bool result, RunOnComputation(computation, stream_exec_, allocator_));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla