/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"

#include <limits>
#include <optional>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/proto/proto_utils.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#endif

namespace xla {
namespace gpu {
namespace {

using se::DeviceMemoryBase;
using se::dnn::AlgorithmDesc;
using std::optional;
using tensorflow::AutotuneResult;

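// Scratch allocator used while profiling convolution algorithms. It retains
// ownership of every buffer it hands out and releases them all when it is
// destroyed, so profiling allocations cannot leak.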
class ScratchAllocator : public se::ScratchAllocator {
 public:
  ScratchAllocator(int device_ordinal,
                   se::DeviceMemoryAllocator* memory_allocator)
      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}

  int64_t GetMemoryLimitInBytes() override {
    return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
  }
  int64_t TotalAllocatedBytes() { return total_allocated_bytes_; }

  StatusOr<se::DeviceMemory<uint8_t>> AllocateBytes(int64_t byte_size) override;

  template <typename T>
  StatusOr<se::DeviceMemory<T>> Allocate(int64_t num_elements) {
    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8_t> bytes,
                        AllocateBytes(num_elements * sizeof(T)));
    return se::DeviceMemory<T>(bytes);
  }

 private:
  const int device_ordinal_;
  se::DeviceMemoryAllocator* memory_allocator_;
  std::vector<se::OwningDeviceMemory> allocated_buffers_;
  int64_t total_allocated_bytes_ = 0;
};

StatusOr<se::DeviceMemory<uint8_t>> ScratchAllocator::AllocateBytes(
    int64_t byte_size) {
  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
  if (byte_size > GetMemoryLimitInBytes()) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        absl::StrFormat(
            "Allocating %d bytes exceeds the memory limit of %d bytes.",
            byte_size, GetMemoryLimitInBytes()));
  }

  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
                      memory_allocator_->Allocate(device_ordinal_, byte_size,
                                                  /*retry_on_failure=*/false));
  total_allocated_bytes_ += byte_size;

  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
  allocated_buffers_.push_back(std::move(allocated_buffer));
  return se::DeviceMemory<uint8_t>(buffer_addr);
}

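// Queries the StreamExecutor for the candidate (possibly fused) convolution
// runners matching `config`, without binding them to any device buffers yet.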
StatusOr<std::vector<MaybeFusedConvRunner>> GetAlgorithms(
    const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend,
    bool use_fallback) {
  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
                      GetDNNConvKindFromCudnnConvKind(config.kind));

  TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type,
                      GetDNNDataTypeFromPrimitiveType(config.input_type));

  TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type,
                      GetDNNDataTypeFromPrimitiveType(config.output_type));

  se::StreamExecutor* stream_exec = stream->parent();

  std::vector<MaybeFusedConvRunner> result;

  switch (kind) {
    default:
      return InternalError("Unknown ConvolutionKind %d", kind);
    case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: {
      if (!config.fusion) {
        return InternalError(
            "GpuConvConfig had fusion ConvolutionKind but no FusionConfig.");
      }
      std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>> runners;
      TF_RETURN_IF_ERROR(stream_exec->GetFusedConvolveRunners(
          use_cudnn_frontend,
          // This refers to the kind of convolution op inside the fusion, not
          // the whole fused graph.
          se::dnn::ConvolutionKind::FORWARD, input_type,
          BiasTypeForInputType(input_type), output_type,
          /* conv_input_scale = */ config.conv_result_scale,
          /* side_input_scale = */ config.fusion->side_input_scale,
          /* leakyrelu_alpha = */ 0.0, stream, config.input_descriptor,
          config.filter_descriptor, GetBiasDescriptor(config),
          config.output_descriptor, config.conv_desc, use_fallback,
          config.fusion->mode, &runners));
      for (auto& runner : runners) {
        TF_ASSIGN_OR_RETURN(
            auto runner_cache,
            se::dnn::LazyOpRunner<se::dnn::FusedConvOp>::FromOpRunner(
                std::move(runner)));
        result.emplace_back(std::move(runner_cache));
      }
      break;
    }

    case se::dnn::ConvolutionKind::FORWARD:
    case se::dnn::ConvolutionKind::BACKWARD_DATA:
    case se::dnn::ConvolutionKind::BACKWARD_FILTER: {
      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
      // This path is cuDNN-only, where the DeviceMemoryBase arguments and the
      // allocator are unused, so they're all provided as nullptr.
      TF_RETURN_IF_ERROR(stream_exec->GetConvolveRunners(
          use_cudnn_frontend, kind, input_type, output_type, stream,
          config.input_descriptor,
          /* input_data = */ DeviceMemoryBase(nullptr),
          config.filter_descriptor,
          /* filter_data = */ DeviceMemoryBase(nullptr),
          config.output_descriptor,
          /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
          use_fallback, nullptr, &runners));
      for (auto& runner : runners) {
        TF_ASSIGN_OR_RETURN(
            auto runner_cache,
            se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
                std::move(runner)));
        result.emplace_back(std::move(runner_cache));
      }
      break;
    }
  }

  return result;
}

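// ROCm counterpart of GetAlgorithms: asks MIOpen for the candidate
// convolution runners for `instr`. Unlike the cuDNN path above, the real
// operand/result buffers and a scratch allocator are passed in here.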
StatusOr<std::vector<std::unique_ptr<const se::dnn::ConvRunner>>>
GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
                    absl::Span<se::DeviceMemoryBase> operand_buffers,
                    se::DeviceMemoryBase result_buffer,
                    se::StreamExecutor* stream_exec,
                    ScratchAllocator* scratch_allocator, se::Stream* stream) {
  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));

  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
                      GetDNNConvKindFromCudnnConvKind(config.kind));

  TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
                      GetDNNDataTypeFromPrimitiveType(config.output_type));

  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
  TF_RETURN_IF_ERROR(stream_exec->GetConvolveRunners(
      /* use_cudnn_frontend = */ false, kind, dtype, dtype, stream,
      params.config->input_descriptor, params.input_buf,
      params.config->filter_descriptor, params.filter_buf,
      params.config->output_descriptor, params.output_buf,
      params.config->conv_desc, /* use_fallback = */ false, scratch_allocator,
      &runners));

  return runners;
}

std::string NumBytesToString(int64_t bytes) {
  return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (",
                      bytes, "B)");
}

tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
  tensorflow::CudnnVersion cudnn_version;
  if (auto* dnn = stream_executor->AsDnn()) {
    StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
    if (version_or.ok()) {
      const auto& version = version_or.ValueOrDie();
      cudnn_version.set_major(version.major_version());
      cudnn_version.set_minor(version.minor_version());
      cudnn_version.set_patch(version.patch());
    }
  }
  return cudnn_version;
}

tensorflow::ComputeCapability GetComputeCapability(
    se::StreamExecutor* stream_executor) {
  tensorflow::ComputeCapability cc;
  se::CudaComputeCapability se_cc =
      stream_executor->GetDeviceDescription().cuda_compute_capability();
  cc.set_major(se_cc.major);
  cc.set_minor(se_cc.minor);
  return cc;
}

void PrintPlatformInfo(const se::Stream* stream) {
  auto* se = stream->parent();
  const auto& desc = se->GetDeviceDescription();
  LOG(ERROR) << "Device: " << desc.name();
  LOG(ERROR) << "Platform: " << desc.platform_version();
  LOG(ERROR) << "Driver: " << desc.driver_version();
  LOG(ERROR) << "Runtime: " << desc.runtime_version();

  auto* dnn = se->AsDnn();
  if (dnn) {
    auto dnn_version = dnn->GetVersion();
    if (dnn_version.ok()) {
      auto v = dnn_version.ValueOrDie();
      LOG(ERROR) << "cudnn version: " << v.major_version() << "."
                 << v.minor_version() << "." << v.patch();
    }
  }
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
// Returns true if the redzones in `allocator`'s allocations are unmodified.
//
// If the redzones are modified, logs an error, sets the appropriate failure
// bits on `result`, and returns false.
//
// Returns an error status if an unexpected error occurred while checking; in
// that case the stream has been poisoned.
//
// `name` is a user-friendly name for the set of redzones being checked, e.g.
// "input/output" or "scratch".
StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
                             se::Stream* stream, absl::string_view name,
                             const HloInstruction* instr,
                             AutotuneResult* result) {
  XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
                                 2);
  using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
  TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
                      allocator.CheckRedzones());
  if (redzone_check.ok()) {
    return true;
  }

  auto* fail = result->mutable_failure();
  fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
  *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
  fail->set_buffer_address(
      reinterpret_cast<uint64_t>(redzone_check.user_buffer_address));

  LOG(ERROR) << absl::StreamFormat(
      "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
      "cudnn bug. We will skip this algorithm in the future, but your GPU "
      "state may already be corrupted, leading to incorrect results. Within "
      "Google, no action is needed on your part. Outside of Google, please "
      "ensure you're running the latest version of cudnn. If that doesn't fix "
      "the problem, please file a bug with this full error message and we'll "
      "contact nvidia.",
      name);
  LOG(ERROR) << redzone_check.RedzoneFailureMsg();
  LOG(ERROR) << "HloInstruction " << instr->ToString();
  PrintPlatformInfo(stream);
  return false;
}
#endif

using ConvCacheKey =
    std::tuple<se::StreamExecutor*,
               /* conv->ToString(HloPrintOptions::Canonical()) */ std::string>;

struct ConvCacheStats {
  int64_t cache_hits = 0;
  int64_t cache_misses = 0;

  void LogStats() {
    VLOG(2) << "Cache hits: " << cache_hits;
    VLOG(2) << "Cache misses: " << cache_misses;
  }
};

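// Builds the cache key for `conv`: the stream executor it will run on plus
// the canonical HLO string of the instruction, including its backend config.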
ConvCacheKey AutotuneCacheKeyfromInstruction(
    const HloCustomCallInstruction* conv, se::StreamExecutor* se) {
  auto options = HloPrintOptions::Canonical();
  options.set_print_backend_config(true);
  return std::make_tuple(se, conv->ToString(options));
}

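// Process-wide cache of autotuning results and its hit/miss statistics, both
// guarded by autotune_cache_lock.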
absl::Mutex autotune_cache_lock(absl::kConstInit);
auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_lock) =
    *new absl::flat_hash_map<ConvCacheKey, AutotuneResult>();
auto& autotune_cache_stats ABSL_GUARDED_BY(autotune_cache_lock) =
    *new ConvCacheStats();
}  // anonymous namespace

StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
    const HloCustomCallInstruction* instr) {
  // Don't run this function concurrently on the same GPU.
  //
  // This is a bit of a hack and doesn't protect us against arbitrary
  // concurrent use of a GPU, but it's sufficient to let us compile two HLO
  // modules concurrently and then run them sequentially.
  //
  // Putting the lock in here rather than in PickBestAlgorithmNoCache lets us
  // avoid ever doing duplicate work.  If we have a cache miss, only one thread
  // will run PickBestAlgorithmImpl for a particular device.
  absl::MutexLock lock(&GetGpuMutex(stream_exec_));

  // We cache the autotuning results to avoid doing duplicate work, which can
  // greatly improve both stability (deterministic numeric results within a
  // process for a given input) and performance (2x speedup on some models).
  ConvCacheKey key = AutotuneCacheKeyfromInstruction(instr, stream_exec_);
  {
    absl::MutexLock lock(&autotune_cache_lock);
    auto it = autotune_cache.find(key);
    if (it != autotune_cache.end()) {
      autotune_cache_stats.cache_hits++;
      return it->second;
    }
    autotune_cache_stats.cache_misses++;
  }

  // Make sure any previous activity on this executor is done. We don't want
  // other work still running on the GPU to interfere with autotuning.
  if (!stream_exec_->SynchronizeAllActivity()) {
    return InternalError(
        "Failed to synchronize GPU for autotuning conv instruction: %s",
        std::get<1>(key) /* instr */);
  }

  // allocator either points to this->allocator_ or, if that's null, to a
  // se::StreamExecutorMemoryAllocator for stream_exec_.
  se::DeviceMemoryAllocator* allocator;
  optional<se::StreamExecutorMemoryAllocator> se_allocator;
  if (allocator_ != nullptr) {
    allocator = allocator_;
  } else {
    se_allocator.emplace(stream_exec_);
    allocator = &*se_allocator;
  }

  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                      allocator->GetStream(stream_exec_->device_ordinal()));
  StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
  // Check which platform the StreamExecutor is on. The ROCm and CUDA
  // implementations have diverged; in particular, redzone-allocator-related
  // utilities must not be used in the ROCm routine.
  if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) {
    result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream);
  } else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) {
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
    result_or = PickBestAlgorithmNoCacheCuda(instr, allocator, stream);
#endif
  }

  if (result_or.ok()) {
    absl::MutexLock lock(&autotune_cache_lock);
    CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
  }
  return result_or;
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)

namespace {
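// xla_gpu_autotune_level controls how aggressively we autotune:
//   >= 2: initialize conv buffers with random data before profiling
//         (ShouldInitConvData);
//   >= 4: additionally check redzones and compare results across algorithms
//         (ShouldCheckConv).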
bool ShouldInitConvData(const HloCustomCallInstruction* instr) {
  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const int32_t conv_autotune_level =
      hlo_module_config.debug_options().xla_gpu_autotune_level();
  return conv_autotune_level >= 2;
}

bool ShouldCheckConv(const HloCustomCallInstruction* instr) {
  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const int32_t conv_autotune_level =
      hlo_module_config.debug_options().xla_gpu_autotune_level();
  return conv_autotune_level >= 4;
}
}  // namespace

// There are three tiers of errors possible here: returning a failed StatusOr
// means autotuning fails immediately; returning an AutotuneResult with a
// failure code other than DISQUALIFIED means autotuning fails if
// crash_on_checking_failure is set; and returning a DISQUALIFIED AutotuneResult
// simply skips the engine/algorithm while recording a reason for skipping it.
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::AutotuneOneConvRunner(
    const GpuConvConfig& config, const HloCustomCallInstruction* instr,
    se::DeviceMemoryAllocator* allocator,
    se::RedzoneAllocator* input_output_allocator, se::Stream* stream,
    MaybeFusedConvRunner* const runner,
    absl::Span<const DeviceMemoryBase> operand_buffers,
    DeviceMemoryBase result_buffer,
    std::optional<ReferenceResult>* reference_result,
    absl::Span<const AlgorithmDesc> disabled_algos) {
  auto alg = runner->ToAlgorithmDesc();

  XLA_SCOPED_LOGGING_TIMER_LEVEL(
      absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
                   alg.ToString()),
      2);

  const auto& hlo_module_config = instr->GetModule()->config();
  const Shape& result_shape = instr->shape().tuple_shapes(0);

  auto make_failure = [&alg](AutotuneResult::FailureKind kind,
                             absl::string_view msg) {
    tensorflow::AutotuneResult result;
    *result.mutable_algorithm() = alg.ToProto();
    result.mutable_failure()->set_kind(kind);
    result.mutable_failure()->set_msg(/* *sigh* */ msg.data(), msg.size());
    return result;
  };

  AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);

  if (absl::c_linear_search(disabled_algos, alg_key)) {
    LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
              << " for conv " << instr->ToString();
    return make_failure(AutotuneResult::DISQUALIFIED,
                        "Disqualified for being known-buggy.");
  }

  auto activation_mode =
      config.fusion ? config.fusion->mode : se::dnn::ActivationMode::kNone;

  // For fused convolutions with the identity function as the activation, only
  // ALGO_IMPLICIT_PRECOMP_GEMM does the right thing. Other algorithms
  // silently do Relu. See
  // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
  //
  // For cuDNN Frontend, there is no way to check whether we're using a broken
  // algorithm, so on versions where some algorithms are broken, we don't use
  // the cuDNN Frontend for these convs at all.  As such, if we get a
  // frontend-based runner, we can be sure it's not one of the broken
  // algorithms we're checking for.
  if (!alg.is_cudnn_frontend() &&
      config.kind == CudnnConvKind::kForwardActivation &&
      activation_mode == se::dnn::ActivationMode::kNone &&
      alg.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
    return make_failure(AutotuneResult::DISQUALIFIED,
                        "Disqualified for implicit RELU.");
  }

  const int64_t rz_space_limit = hlo_module_config.debug_options()
                                     .xla_gpu_redzone_scratch_max_megabytes() *
                                 (1LL << 20);
  se::RedzoneAllocator scratch_allocator(
      stream, allocator,
      PtxOptsFromDebugOptions(hlo_module_config.debug_options()),
      /*memory_limit=*/rz_space_limit);
  se::dnn::ProfileResult profile_result;
  VLOG(3) << "Trying algorithm " << alg.ToString() << " for "
          << instr->ToString();

  std::optional<size_t> workspace_size =
      runner->ToAlgorithmDesc().workspace_size();
  if (!workspace_size) {
    return make_failure(AutotuneResult::UNKNOWN,
                        "Internal error: missing workspace size from "
                        "OpRunner::ToAlgorithmDesc()");
  }

  auto scratch_or = scratch_allocator.AllocateBytes(*workspace_size);
  if (!scratch_or.ok()) {
    return make_failure(AutotuneResult::DISQUALIFIED,
                        absl::StrCat("Scratch allocation failed: ",
                                     scratch_or.status().ToString()));
  }
  se::DeviceMemoryBase scratch_memory = scratch_or.ValueOrDie();

  // Use assignment instead of brace-list to make GCC 4.9 happy.
  RunConvOptions options;
  options.profile_result = &profile_result;
  options.runner_cache = runner;
  Status launch_status = RunGpuConv(config, operand_buffers, result_buffer,
                                    scratch_memory, stream, options);

  if (!launch_status.ok()) {
    VLOG(4) << "Launch failed: " << launch_status;
    return make_failure(
        AutotuneResult::DISQUALIFIED,
        absl::StrCat("Profiling failure on cuDNN engine ", alg.ToString(), ": ",
                     launch_status.ToString()));
  }

  if (!profile_result.is_valid()) {
    VLOG(4) << "Launch succeeded but profile result is invalid.";
    // Not DISQUALIFIED: this means something went wrong internally.
    return make_failure(
        AutotuneResult::UNKNOWN,
        absl::StrCat("Launch succeeded but profile result is invalid, "
                     "with cuDNN engine ",
                     alg.ToString(), ": ", launch_status.ToString()));
  }

  int64_t scratch_bytes_used =
      scratch_allocator.TotalAllocatedBytesExcludingRedzones();

  tensorflow::AutotuneResult result;
  *result.mutable_algorithm() = alg.ToProto();
  result.set_scratch_bytes(scratch_bytes_used);
  *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
      absl::Milliseconds(profile_result.elapsed_time_in_ms()));

  if (!ShouldCheckConv(instr)) {
    if (!reference_result->has_value()) {
      (*reference_result) = {alg, DeviceMemoryBase()};
    }
    return result;
  }

  // Check for writes to redzones.
  TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
                      CheckRedzones(*input_output_allocator, stream,
                                    "input/output", instr, &result));

  TF_ASSIGN_OR_RETURN(
      bool scratch_allocator_redzone_clear,
      CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));

  if (!input_output_allocator_redzone_clear ||
      !scratch_allocator_redzone_clear) {
    std::string canonical_hlo =
        std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));

    std::string blas_version;
    if (auto* blas = stream_exec_->AsBlas()) {
      (void)blas->GetVersion(&blas_version);
    }

    AlgorithmDenylist proto;
    auto entry = proto.add_entries();
    entry->set_hlo(canonical_hlo);
    *entry->mutable_cc() = GetComputeCapability(stream_exec_);
    *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
    entry->set_blas_version(blas_version);
    auto algo = entry->add_algos();
    algo->set_id(alg.algo_id());
    algo->set_tensor_ops(alg.tensor_ops_enabled());

    LOG(ERROR) << "To denylist this algorithm for this convolution, "
                  "copy-paste the following "
                  "proto to the denylist file pointed by XLA_FLAGS "
                  "--xla_gpu_algorithm_denylist_path="
               << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
               << " : " << proto.ShortDebugString();

    // CheckRedzones has modified the result in-place to include a failure.
    return result;
  }

  if (reference_result->has_value()) {
    XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
    BufferComparator comparator(result_shape, hlo_module_config);
    StatusOr<bool> compare_result = comparator.CompareEqual(
        stream, (*reference_result)->buffer, result_buffer);
    if (!compare_result.ok()) {
      LOG(ERROR) << "Unable to compare "
                 << (*reference_result)->algorithm.ToString() << " against "
                 << alg.ToString() << " for " << instr->ToString() << ": "
                 << compare_result.status();
      if (compare_result.status().code() ==
          tensorflow::error::RESOURCE_EXHAUSTED) {
        // Possibly OOM. Propagate the error.
        return compare_result.status();
      }
      const DebugOptions& debug_options =
          instr->GetModule()->config().debug_options();
      CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
    } else if (!compare_result.ValueOrDie()) {
      LOG(ERROR)
          << "Results mismatch between different convolution algorithms. "
             "This is likely a bug/unexpected loss of precision in cudnn.\n"
          << instr->ToString() << " for "
          << (*reference_result)->algorithm.ToString() << " vs "
          << alg.ToString();
      PrintPlatformInfo(stream);
      VLOG(1) << "Full module on failure: \n" << instr->GetModule()->ToString();
      auto* fail = result.mutable_failure();
      fail->set_kind(AutotuneResult::WRONG_RESULT);
      fail->set_buffer_address(
          reinterpret_cast<uint64_t>(result_buffer.opaque()));
      *fail->mutable_reference_algorithm() =
          (*reference_result)->algorithm.ToProto();
    }
  } else {
    XLA_SCOPED_LOGGING_TIMER_LEVEL("Memcpy Reference Result", 2);
    TF_ASSIGN_OR_RETURN(
        auto reference_result_buffer,
        input_output_allocator->AllocateBytes(result_buffer.size()));
    stream->ThenMemcpy(&reference_result_buffer, result_buffer,
                       result_buffer.size());
    (*reference_result) = {alg, reference_result_buffer};
  }

  return result;
}

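// Profiles every candidate cuDNN algorithm for `instr`: first the ones the
// cuDNN heuristics suggest, then (only if none of those worked) the fallback
// list. When checking is enabled, the first successful algorithm's output is
// kept as the reference that later algorithms are compared against.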
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  // Right now the redzone allocator is available for the CUDA target only.
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));

  const Shape& result_shape = instr->shape().tuple_shapes(0);
  int64_t rng_state = 0;

  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const bool init_conv_data = ShouldInitConvData(instr);
  const auto initialize_buffer = [init_conv_data, &stream, &rng_state](
                                     DeviceMemoryBase buffer,
                                     const Shape& buffer_shape) {
    if (init_conv_data) {
      InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
    }
  };

  // Allocate space for the input, filter, and output of the convolution.
  const int64_t redzone_size =
      ShouldCheckConv(instr) ? se::RedzoneAllocator::kDefaultRedzoneSize : 0;
  se::RedzoneAllocator input_output_allocator(
      stream, allocator,
      PtxOptsFromDebugOptions(hlo_module_config.debug_options()),
      /*memory_limit=*/std::numeric_limits<int64_t>::max(),
      /*redzone_size=*/redzone_size);
  std::vector<se::DeviceMemoryBase> operand_buffers;
  for (const auto* operand : instr->operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer, operand->shape());
    operand_buffers.push_back(buffer);
  }
  TF_ASSIGN_OR_RETURN(auto result_buffer,
                      input_output_allocator.AllocateBytes(
                          ShapeUtil::ByteSizeOf(result_shape)));
  initialize_buffer(result_buffer, result_shape);

  const DebugOptions& debug_options =
      instr->GetModule()->config().debug_options();

  const bool crash_on_checking_failure =
      debug_options.xla_gpu_crash_on_verification_failures();

  std::string canonical_hlo =
      std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));

  std::string blas_version;
  if (auto* blas = stream_exec_->AsBlas()) {
    (void)blas->GetVersion(&blas_version);
  }

  absl::Span<const AlgorithmDesc> disabled_algos = GetDisabledConvAlgorithms(
      GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_),
      blas_version, canonical_hlo);

  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));

  const bool cudnn_frontend_enabled =
      debug_options.xla_gpu_enable_cudnn_frontend();

  // The first algorithm that runs successfully is used as the reference for
  // comparing the others against. There is no particular reason to prefer it;
  // any algorithm would do, and being the reference does not make it the
  // "correct" one.
  std::optional<ReferenceResult> reference_result;

  TF_ASSIGN_OR_RETURN(std::vector<MaybeFusedConvRunner> runners,
                      GetAlgorithms(config, stream, cudnn_frontend_enabled,
                                    /* use_fallback = */ false));

  std::vector<AutotuneResult> profile_results;
  for (auto& runner_cache : runners) {
    TF_ASSIGN_OR_RETURN(
        auto result, AutotuneOneConvRunner(
                         config, instr, allocator, &input_output_allocator,
                         stream, &runner_cache, operand_buffers, result_buffer,
                         &reference_result, disabled_algos));
    profile_results.emplace_back(std::move(result));
  }

  // If any algorithm has worked, we'll skip the fallback algorithms, since
  // they include some very slow algorithms.
  if (!reference_result) {
    LOG(WARNING) << "None of the algorithms provided by cuDNN heuristics "
                    "worked; trying fallback algorithms.  Conv: "
                 << canonical_hlo;

    TF_ASSIGN_OR_RETURN(std::vector<MaybeFusedConvRunner> fallback_runners,
                        GetAlgorithms(config, stream, cudnn_frontend_enabled,
                                      /* use_fallback = */ true));

    for (auto& runner_cache : fallback_runners) {
      TF_ASSIGN_OR_RETURN(
          auto result, AutotuneOneConvRunner(
                           config, instr, allocator, &input_output_allocator,
                           stream, &runner_cache, operand_buffers,
                           result_buffer, &reference_result, disabled_algos));
      profile_results.emplace_back(std::move(result));
    }
  }

  // Log the autotuning result.
  {
    tensorflow::AutotuningLog log;
    {
      ConvInstructionLog instr_log;
      *instr_log.mutable_instruction() = instr->ToProto();
      for (int i = 0; i < instr->operand_count(); i++) {
        *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
        instr_log.add_operand_addresses(
            reinterpret_cast<uint64_t>(operand_buffers[i].opaque()));
      }
      instr_log.set_result_address(
          reinterpret_cast<uint64_t>(result_buffer.opaque()));
      log.mutable_instr()->PackFrom(instr_log);
    }
    for (const auto& profile : profile_results) {
      *log.add_results() = profile;
    }
    *log.mutable_compute_capability() = GetComputeCapability(stream_exec_);
    *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
    log.set_device_pci_bus_id(
        stream_exec_->GetDeviceDescription().pci_bus_id());
    log.set_blas_version(blas_version);
    VLOG(1) << "Autotuning result: " << log.ShortDebugString();
    // If we crash on checking failures, we are in a testing/benchmark mode, so
    // skip logging through the logger.
    if (!crash_on_checking_failure) {
      tensorflow::Logger::GetSingleton()->LogProto(log);
    } else {
      // Crash on miscompares and redzone violations if desired.
      for (const auto& profile : profile_results) {
        if (profile.has_failure() &&
            profile.failure().kind() != AutotuneResult::DISQUALIFIED) {
          LOG(FATAL) << "crash_on_checking_failure encountered errors:\n\n"
                     << log.DebugString();
        }
      }
    }
  }

  TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
                      PickBestResult(profile_results, *instr));
  return selected_algorithm;
}
#endif

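// ROCm/MIOpen autotuning path. Profiles the runners MIOpen returns; if MIOpen
// reports exactly one runner, it is accepted without profiling (its run time
// is recorded as -1 ms).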
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));

  const auto device_ordinal = stream_exec_->device_ordinal();
  std::vector<se::DeviceMemoryBase> operand_buffers;

  ScratchAllocator input_output_allocator(device_ordinal, allocator);
  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
    // Although we don't have evidence this matters, zero out the buffers
    // before autotuning.  It's conceivable that using uninitialized memory as
    // the inputs might affect performance if e.g. the inputs contain
    // denormals, and this is easy enough.
    stream->ThenMemZero(&buffer, buffer.size());
  };

  // Allocate space for the input, filter, and output of the convolution.  We
  // use a ScratchAllocator for this instead of calling allocator_ directly so
  // that our allocations don't leak.
  for (const auto* operand : instr->operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer);
    operand_buffers.push_back(buffer);
  }

  TF_ASSIGN_OR_RETURN(
      auto result_buffer,
      input_output_allocator.AllocateBytes(
          ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
  initialize_buffer(result_buffer);

  ScratchAllocator scratch_allocator(device_ordinal, allocator);

  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners,
      GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers), result_buffer,
                          stream_exec_, &scratch_allocator, stream));

  std::vector<AutotuneResult> profile_results;

  if (runners.size() == 1) {
    TF_ASSIGN_OR_RETURN(auto alg, runners[0]->ToAlgorithmDesc());
    auto algorithm_proto = alg.ToProto();
    profile_results.emplace_back();
    auto& result = profile_results.back();
    *result.mutable_algorithm() = algorithm_proto;

    result.set_scratch_bytes(runners[0]->GetWorkspaceSize());

    // TODO(awpr): if the profile result time for a singleton algorithm is
    // needed, plumb it via OpRunner; we'll need to do this to let TF ops avoid
    // re-profiling ROCm algorithms anyway.
    *result.mutable_run_time() =
        tensorflow::proto_utils::ToDurationProto(absl::Milliseconds(-1));
  } else {
    TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
    for (auto& runner : runners) {
      TF_ASSIGN_OR_RETURN(auto alg, runner->ToAlgorithmDesc());
      XLA_SCOPED_LOGGING_TIMER_LEVEL(
          absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
                       alg.ToString()),
          2);

      se::dnn::ProfileResult profile_result;
      VLOG(3) << "Trying algorithm " << alg.ToString() << " for "
              << instr->ToString();

      TF_ASSIGN_OR_RETURN(
          DeviceMemoryBase scratch_memory,
          scratch_allocator.AllocateBytes(runner->GetWorkspaceSize()));

      TF_ASSIGN_OR_RETURN(auto lazy_runner,
                          se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
                              std::move(runner)));

      MaybeFusedConvRunner runner_cache(std::move(lazy_runner));

      // Use assignment instead of brace-list to make GCC 4.9 happy.
      RunConvOptions options;
      options.profile_result = &profile_result;
      options.runner_cache = &runner_cache;
      Status launch_status =
          RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer,
                     scratch_memory, stream, options);

      if (!launch_status.ok()) {
        continue;
      }

      if (!profile_result.is_valid()) {
        continue;
      }

      profile_results.emplace_back();
      AutotuneResult& result = profile_results.back();
      *result.mutable_algorithm() = alg.ToProto();

      int64_t scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
      result.set_scratch_bytes(scratch_bytes_used);
      *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    }
  }

  TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
                      PickBestResult(profile_results, *instr));
  return selected_algorithm;
}

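// Autotunes a single conv custom-call and rewrites it in place: the chosen
// algorithm (and its workspace size) is recorded in the backend config, and
// the custom-call's u8 scratch output is resized to match.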
StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) {
  CHECK(IsCustomCallToDnnConvolution(*instr));

  const bool strict = instr->parent()
                          ->parent()
                          ->config()
                          .debug_options()
                          .xla_gpu_strict_conv_algorithm_picker();

  StatusOr<AutotuneResult> best_algo_or =
      PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
  if (!best_algo_or.ok()) {
    auto msg = absl::StrFormat(
        "Failed to determine best cudnn convolution algorithm for:\n%s\n\n"
        "Original error: %s",
        instr->ToString(), best_algo_or.status().ToString());

    if (strict) {
      return Unknown(
          "%s\n\nTo ignore this failure and try to use a fallback algorithm "
          "(which may have suboptimal performance), use "
          "XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please "
          "also file a bug for the root cause of failing autotuning.",
          msg);
    }
    LOG(WARNING)
        << msg << "\n\nAs a result, convolution performance may be suboptimal.";
    return false;
  }

  auto best_algo = std::move(best_algo_or).ValueOrDie();
  VLOG(2) << "Setting cudnn conv to use algorithm "
          << best_algo.conv().algorithm() << " and "
          << NumBytesToString(best_algo.scratch_bytes())
          << " of scratch memory: " << instr->ToString()
          << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();

  // Replace instr with a new CustomCall which has the correct algorithm, and
  // whose output shape has the appropriate amount of scratch memory.
  HloComputation* computation = instr->parent();
  Shape new_call_shape = ShapeUtil::MakeTupleShape(
      {instr->shape().tuple_shapes(0),
       ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})});

  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                      instr->backend_config<CudnnConvBackendConfig>());
  *backend_config.mutable_algorithm() = best_algo.algorithm();
  backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
      best_algo.scratch_bytes());

  HloInstruction* new_call = computation->AddInstruction(
      instr->CloneWithNewOperands(new_call_shape, instr->operands()));

  // Preserve the name of the old instruction.  This is safe because we're
  // going to remove the old one anyway, and it makes it easier to trace how
  // our conv is transformed through all our passes.
  new_call->SetAndSanitizeName(instr->name());

  VLOG(2) << "Replacing convolution " << instr->ToString() << " with "
          << new_call->ToString();

  TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));

  // Repackage new_call so it has the same shape as the original call, namely
  // (conv_result, u8[0]).
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(
          {computation->AddInstruction(HloInstruction::CreateGetTupleElement(
               new_call_shape.tuple_shapes(0), new_call, 0)),
           computation->AddInstruction(HloInstruction::CreateConstant(
               LiteralUtil::CreateR1<uint8_t>({})))}));

  TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
  return true;
}

StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
    HloComputation* computation) {
  std::vector<HloInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(instr);
    }
  }

  bool changed = false;
  for (auto* instr : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
    changed |= result;
  }
  return changed;
}

StatusOr<bool> GpuConvAlgorithmPicker::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");

  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
    VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
               "returning early.";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }

  {
    absl::MutexLock lock(&autotune_cache_lock);
    autotune_cache_stats.LogStats();
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla