/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"

#include <limits>
#include <optional>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/proto/proto_utils.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#endif

namespace xla {
namespace gpu {
namespace {

using se::DeviceMemoryBase;
using se::dnn::AlgorithmDesc;
using std::optional;
using tensorflow::AutotuneResult;

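// Scratch allocator used during autotuning. It keeps ownership of every
// buffer it hands out, so allocations stay alive until the allocator itself
// is destroyed when autotuning of one convolution finishes.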
class ScratchAllocator : public se::ScratchAllocator {
 public:
  ScratchAllocator(int device_ordinal,
                   se::DeviceMemoryAllocator* memory_allocator)
      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}

  int64_t GetMemoryLimitInBytes() override {
    return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
  }
  int64_t TotalAllocatedBytes() { return total_allocated_bytes_; }

  StatusOr<se::DeviceMemory<uint8_t>> AllocateBytes(int64_t byte_size) override;

  template <typename T>
  StatusOr<se::DeviceMemory<T>> Allocate(int64_t num_elements) {
    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8_t> bytes,
                        AllocateBytes(num_elements * sizeof(T)));
    return se::DeviceMemory<T>(bytes);
  }

 private:
  const int device_ordinal_;
  se::DeviceMemoryAllocator* memory_allocator_;
  std::vector<se::OwningDeviceMemory> allocated_buffers_;
  int64_t total_allocated_bytes_ = 0;
};

StatusOr<se::DeviceMemory<uint8_t>> ScratchAllocator::AllocateBytes(
    int64_t byte_size) {
  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
  if (byte_size > GetMemoryLimitInBytes()) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        absl::StrFormat(
            "Allocating %d bytes exceeds the memory limit of %d bytes.",
            byte_size, GetMemoryLimitInBytes()));
  }

  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
                      memory_allocator_->Allocate(device_ordinal_, byte_size,
                                                  /*retry_on_failure=*/false));
  total_allocated_bytes_ += byte_size;

  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
  allocated_buffers_.push_back(std::move(allocated_buffer));
  return se::DeviceMemory<uint8_t>(buffer_addr);
}

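// Asks the DNN backend (cuDNN, via StreamExecutor) for the runners able to
// execute the convolution described by `config`. Fused convolutions get
// fused-conv runners; plain forward/backward-data/backward-filter convs get
// ordinary conv runners. `use_fallback` requests the backend's fallback
// engine list instead of the heuristics-recommended one.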
StatusOr<std::vector<MaybeFusedConvRunner>> GetAlgorithms(
    const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend,
    bool use_fallback) {
  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
                      GetDNNConvKindFromCudnnConvKind(config.kind));

  TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type,
                      GetDNNDataTypeFromPrimitiveType(config.input_type));

  TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type,
                      GetDNNDataTypeFromPrimitiveType(config.output_type));

  se::StreamExecutor* stream_exec = stream->parent();

  std::vector<MaybeFusedConvRunner> result;

  switch (kind) {
    default:
      return InternalError("Unknown ConvolutionKind %d", kind);
    case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: {
      if (!config.fusion) {
        return InternalError(
            "GpuConvConfig had fusion ConvolutionKind but no FusionConfig.");
      }
      std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>> runners;
      TF_RETURN_IF_ERROR(stream_exec->GetFusedConvolveRunners(
          use_cudnn_frontend,
          // This refers to the kind of convolution op inside the fusion, not
          // the whole fused graph.
          se::dnn::ConvolutionKind::FORWARD, input_type,
          BiasTypeForInputType(input_type), output_type,
          /* conv_input_scale = */ config.conv_result_scale,
          /* side_input_scale = */ config.fusion->side_input_scale,
          /* leakyrelu_alpha = */ 0.0, stream, config.input_descriptor,
          config.filter_descriptor, GetBiasDescriptor(config),
          config.output_descriptor, config.conv_desc, use_fallback,
          config.fusion->mode, &runners));
      for (auto& runner : runners) {
        TF_ASSIGN_OR_RETURN(
            auto runner_cache,
            se::dnn::LazyOpRunner<se::dnn::FusedConvOp>::FromOpRunner(
                std::move(runner)));
        result.emplace_back(std::move(runner_cache));
      }
      break;
    }

    case se::dnn::ConvolutionKind::FORWARD:
    case se::dnn::ConvolutionKind::BACKWARD_DATA:
    case se::dnn::ConvolutionKind::BACKWARD_FILTER: {
      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
      // This path is cuDNN-only, where the DeviceMemoryBase arguments and the
      // allocator are unused, so they're all provided as nullptr.
      TF_RETURN_IF_ERROR(stream_exec->GetConvolveRunners(
          use_cudnn_frontend, kind, input_type, output_type, stream,
          config.input_descriptor,
          /* input_data = */ DeviceMemoryBase(nullptr),
          config.filter_descriptor,
          /* filter_data = */ DeviceMemoryBase(nullptr),
          config.output_descriptor,
          /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
          use_fallback, nullptr, &runners));
      for (auto& runner : runners) {
        TF_ASSIGN_OR_RETURN(
            auto runner_cache,
            se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
                std::move(runner)));
        result.emplace_back(std::move(runner_cache));
      }
      break;
    }
  }

  return result;
}

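// ROCm counterpart of GetAlgorithms: asks MIOpen (through StreamExecutor) for
// the conv runners available for `instr`. Unlike the cuDNN path above, MIOpen
// needs the real operand/result buffers and a scratch allocator already at
// query time.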
StatusOr<std::vector<std::unique_ptr<const se::dnn::ConvRunner>>>
GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
                    absl::Span<se::DeviceMemoryBase> operand_buffers,
                    se::DeviceMemoryBase result_buffer,
                    se::StreamExecutor* stream_exec,
                    ScratchAllocator* scratch_allocator, se::Stream* stream) {
  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));

  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
                      GetDNNConvKindFromCudnnConvKind(config.kind));

  TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
                      GetDNNDataTypeFromPrimitiveType(config.output_type));

  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
  TF_RETURN_IF_ERROR(stream_exec->GetConvolveRunners(
      /* use_cudnn_frontend = */ false, kind, dtype, dtype, stream,
      params.config->input_descriptor, params.input_buf,
      params.config->filter_descriptor, params.filter_buf,
      params.config->output_descriptor, params.output_buf,
      params.config->conv_desc, /* use_fallback = */ false, scratch_allocator,
      &runners));

  return runners;
}

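// Pretty-prints a byte count, appending the exact number of bytes in
// parentheses.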
std::string NumBytesToString(int64_t bytes) {
  return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (",
                      bytes, "B)");
}

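// Reads the cuDNN version from the StreamExecutor's DNN backend; if the
// backend or the version query is unavailable, the returned proto keeps its
// default (zero) fields.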
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
  tensorflow::CudnnVersion cudnn_version;
  if (auto* dnn = stream_executor->AsDnn()) {
    StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
    if (version_or.ok()) {
      const auto& version = version_or.ValueOrDie();
      cudnn_version.set_major(version.major_version());
      cudnn_version.set_minor(version.minor_version());
      cudnn_version.set_patch(version.patch());
    }
  }
  return cudnn_version;
}

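// Copies the device's CUDA compute capability into its autotuning proto
// representation.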
tensorflow::ComputeCapability GetComputeCapability(
    se::StreamExecutor* stream_executor) {
  tensorflow::ComputeCapability cc;
  se::CudaComputeCapability se_cc =
      stream_executor->GetDeviceDescription().cuda_compute_capability();
  cc.set_major(se_cc.major);
  cc.set_minor(se_cc.minor);
  return cc;
}

void PrintPlatformInfo(const se::Stream* stream) {
  auto* se = stream->parent();
  const auto& desc = se->GetDeviceDescription();
  LOG(ERROR) << "Device: " << desc.name();
  LOG(ERROR) << "Platform: " << desc.platform_version();
  LOG(ERROR) << "Driver: " << desc.driver_version();
  LOG(ERROR) << "Runtime: " << desc.runtime_version();

  auto* dnn = se->AsDnn();
  if (dnn) {
    auto dnn_version = dnn->GetVersion();
    if (dnn_version.ok()) {
      auto v = dnn_version.ValueOrDie();
      LOG(ERROR) << "cudnn version: " << v.major_version() << "."
                 << v.minor_version() << "." << v.patch();
    }
  }
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
// Returns true if the redzones in `allocator`'s allocations are unmodified.
//
// If the redzones are modified, logs an error, sets the appropriate failure
// bits on `result`, and returns false.
//
// Returns a status if an unexpected error has occurred, and the stream
// has been poisoned.
//
// `name` is a user-friendly name for the set of redzones being checked, e.g.
// "input/output" or "scratch".
StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
                             se::Stream* stream, absl::string_view name,
                             const HloInstruction* instr,
                             AutotuneResult* result) {
  XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
                                 2);
  using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
  TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
                      allocator.CheckRedzones());
  if (redzone_check.ok()) {
    return true;
  }

  auto* fail = result->mutable_failure();
  fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
  *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
  fail->set_buffer_address(
      reinterpret_cast<uint64_t>(redzone_check.user_buffer_address));

  LOG(ERROR) << absl::StreamFormat(
      "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
      "cudnn bug. We will skip this algorithm in the future, but your GPU "
      "state may already be corrupted, leading to incorrect results. Within "
      "Google, no action is needed on your part. Outside of Google, please "
      "ensure you're running the latest version of cudnn. If that doesn't fix "
      "the problem, please file a bug with this full error message and we'll "
      "contact nvidia.",
      name);
  LOG(ERROR) << redzone_check.RedzoneFailureMsg();
  LOG(ERROR) << "HloInstruction " << instr->ToString();
  PrintPlatformInfo(stream);
  return false;
}
#endif

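// The autotune cache is keyed on the executing device and the canonicalized
// HLO of the convolution (including its backend config), so identical convs
// compiled later in the same process reuse the cached result.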
using ConvCacheKey =
    std::tuple<se::StreamExecutor*,
               /* conv->ToString(HloPrintOptions::Canonical()) */ std::string>;

struct ConvCacheStats {
  int64_t cache_hits = 0;
  int64_t cache_misses = 0;

  void LogStats() {
    VLOG(2) << "Cache hits: " << cache_hits;
    VLOG(2) << "Cache misses: " << cache_misses;
  }
};

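// Builds the cache key for `conv` on executor `se`, using the canonical HLO
// string with the backend config included.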
ConvCacheKey AutotuneCacheKeyfromInstruction(
    const HloCustomCallInstruction* conv, se::StreamExecutor* se) {
  auto options = HloPrintOptions::Canonical();
  options.set_print_backend_config(true);
  return std::make_tuple(se, conv->ToString(options));
}

absl::Mutex autotune_cache_lock(absl::kConstInit);
auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_lock) =
    *new absl::flat_hash_map<ConvCacheKey, AutotuneResult>();
auto& autotune_cache_stats ABSL_GUARDED_BY(autotune_cache_lock) =
    *new ConvCacheStats();
}  // anonymous namespace

StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
    const HloCustomCallInstruction* instr) {
  // Don't run this function concurrently on the same GPU.
  //
  // This is a bit of a hack and doesn't protect us against arbitrary concurrent
  // use of a GPU, but it's sufficient to let us compile two HLO modules
  // concurrently and then run them sequentially.
  //
  // Putting the lock in here rather than in PickBestAlgorithmNoCache lets us
  // avoid ever doing duplicate work.  If we have a cache miss, only one thread
  // will run PickBestAlgorithmImpl for a particular device.
  absl::MutexLock lock(&GetGpuMutex(stream_exec_));

  // We cache the autotuning results to avoid doing duplicate work, which can
  // greatly improve both stability (deterministic numeric results within a
  // process for a given input) and performance (2x speedup on some models).
  ConvCacheKey key = AutotuneCacheKeyfromInstruction(instr, stream_exec_);
  {
    absl::MutexLock lock(&autotune_cache_lock);
    auto it = autotune_cache.find(key);
    if (it != autotune_cache.end()) {
      autotune_cache_stats.cache_hits++;
      return it->second;
    }
    autotune_cache_stats.cache_misses++;
  }

  // Make sure any previous activity on this executor is done. We don't want
  // other work still running on the GPU to interfere with autotuning.
  if (!stream_exec_->SynchronizeAllActivity()) {
    return InternalError(
        "Failed to synchronize GPU for autotuning conv instruction: %s",
        std::get<1>(key) /* instr */);
  }

  // allocator either points to this->allocator_ or, if that's null, to a
  // se::StreamExecutorMemoryAllocator for stream_exec_.
  se::DeviceMemoryAllocator* allocator;
  optional<se::StreamExecutorMemoryAllocator> se_allocator;
  if (allocator_ != nullptr) {
    allocator = allocator_;
  } else {
    se_allocator.emplace(stream_exec_);
    allocator = &*se_allocator;
  }

  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                      allocator->GetStream(stream_exec_->device_ordinal()));
  StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
  // Check which platform the StreamExecutor is running on: the ROCm and CUDA
  // implementations have diverged, and in particular the redzone-allocator
  // utilities must not be used in the ROCm path.
  if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) {
    result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream);
  } else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) {
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
    result_or = PickBestAlgorithmNoCacheCuda(instr, allocator, stream);
#endif
  }

  if (result_or.ok()) {
    absl::MutexLock lock(&autotune_cache_lock);
    CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
  }
  return result_or;
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)

namespace {
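// At autotune level >= 2 the conv's operand and result buffers are
// initialized before profiling instead of being left as whatever data happens
// to be in the freshly allocated memory.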
bool ShouldInitConvData(const HloCustomCallInstruction* instr) {
  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const int32_t conv_autotune_level =
      hlo_module_config.debug_options().xla_gpu_autotune_level();
  return conv_autotune_level >= 2;
}

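// At autotune level >= 4 each profiled algorithm is additionally verified:
// redzones are checked for out-of-bounds writes and results are compared
// against a reference algorithm.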
bool ShouldCheckConv(const HloCustomCallInstruction* instr) {
  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const int32_t conv_autotune_level =
      hlo_module_config.debug_options().xla_gpu_autotune_level();
  return conv_autotune_level >= 4;
}
}  // namespace

// There are three tiers of errors possible here: returning a failed StatusOr
// means autotuning fails immediately; returning an AutotuneResult with a
// failure code other than DISQUALIFIED means autotuning fails if
// crash_on_checking_failure is set; and returning a DISQUALIFIED AutotuneResult
// simply skips the engine/algorithm while recording a reason for skipping it.
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::AutotuneOneConvRunner(
    const GpuConvConfig& config, const HloCustomCallInstruction* instr,
    se::DeviceMemoryAllocator* allocator,
    se::RedzoneAllocator* input_output_allocator, se::Stream* stream,
    MaybeFusedConvRunner* const runner,
    absl::Span<const DeviceMemoryBase> operand_buffers,
    DeviceMemoryBase result_buffer,
    std::optional<ReferenceResult>* reference_result,
    absl::Span<const AlgorithmDesc> disabled_algos) {
  auto alg = runner->ToAlgorithmDesc();

  XLA_SCOPED_LOGGING_TIMER_LEVEL(
      absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
                   alg.ToString()),
      2);

  const auto& hlo_module_config = instr->GetModule()->config();
  const Shape& result_shape = instr->shape().tuple_shapes(0);

  auto make_failure = [&alg](AutotuneResult::FailureKind kind,
                             absl::string_view msg) {
    tensorflow::AutotuneResult result;
    *result.mutable_algorithm() = alg.ToProto();
    result.mutable_failure()->set_kind(kind);
    result.mutable_failure()->set_msg(/* *sigh* */ msg.data(), msg.size());
    return result;
  };

  AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);

  if (absl::c_linear_search(disabled_algos, alg_key)) {
    LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
              << " for conv " << instr->ToString();
    return make_failure(AutotuneResult::DISQUALIFIED,
                        "Disqualified for being known-buggy.");
  }

  auto activation_mode =
      config.fusion ? config.fusion->mode : se::dnn::ActivationMode::kNone;

  // For fused convolutions with the identity function as the activation, only
  // ALGO_IMPLICIT_PRECOMP_GEMM does the right thing. Other algorithms
  // silently do Relu. See
  // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
  //
  // For cuDNN Frontend, there is no way to check whether we're using a broken
  // algorithm, so on versions where some algorithms are broken, we don't use
  // the cuDNN Frontend for these convs at all.  As such, if we get a
  // frontend-based runner, we can be sure it's not one of the broken
  // algorithms we're checking for.
  if (!alg.is_cudnn_frontend() &&
      config.kind == CudnnConvKind::kForwardActivation &&
      activation_mode == se::dnn::ActivationMode::kNone &&
      alg.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
    return make_failure(AutotuneResult::DISQUALIFIED,
                        "Disqualified for implicit RELU.");
  }

  const int64_t rz_space_limit = hlo_module_config.debug_options()
                                     .xla_gpu_redzone_scratch_max_megabytes() *
                                 (1LL << 20);
  se::RedzoneAllocator scratch_allocator(
      stream, allocator,
      PtxOptsFromDebugOptions(hlo_module_config.debug_options()),
      /*memory_limit=*/rz_space_limit);
  se::dnn::ProfileResult profile_result;
  VLOG(3) << "Trying algorithm " << alg.ToString() << " for "
          << instr->ToString();

  std::optional<size_t> workspace_size =
      runner->ToAlgorithmDesc().workspace_size();
  if (!workspace_size) {
    return make_failure(AutotuneResult::UNKNOWN,
                        "Internal error: missing workspace size from "
                        "OpRunner::ToAlgorithmDesc()");
  }

  auto scratch_or = scratch_allocator.AllocateBytes(*workspace_size);
  if (!scratch_or.ok()) {
    return make_failure(AutotuneResult::DISQUALIFIED,
                        absl::StrCat("Scratch allocation failed: ",
                                     scratch_or.status().ToString()));
  }
  se::DeviceMemoryBase scratch_memory = scratch_or.ValueOrDie();

  // Use assignment instead of brace-list to make GCC 4.9 happy.
  RunConvOptions options;
  options.profile_result = &profile_result;
  options.runner_cache = runner;
  Status launch_status = RunGpuConv(config, operand_buffers, result_buffer,
                                    scratch_memory, stream, options);

  if (!launch_status.ok()) {
    VLOG(4) << "Launch failed: " << launch_status;
    return make_failure(
        AutotuneResult::DISQUALIFIED,
        absl::StrCat("Profiling failure on cuDNN engine ", alg.ToString(), ": ",
                     launch_status.ToString()));
  }

  if (!profile_result.is_valid()) {
    VLOG(4) << "Launch succeeded but profile result is invalid.";
    // Not DISQUALIFIED: this means something went wrong internally.
    return make_failure(
        AutotuneResult::UNKNOWN,
        absl::StrCat("Launch succeeded but profile result is invalid, "
                     "with cuDNN engine ",
                     alg.ToString(), ": ", launch_status.ToString()));
  }

  int64_t scratch_bytes_used =
      scratch_allocator.TotalAllocatedBytesExcludingRedzones();

  tensorflow::AutotuneResult result;
  *result.mutable_algorithm() = alg.ToProto();
  result.set_scratch_bytes(scratch_bytes_used);
  *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
      absl::Milliseconds(profile_result.elapsed_time_in_ms()));

  if (!ShouldCheckConv(instr)) {
    if (!reference_result->has_value()) {
      (*reference_result) = {alg, DeviceMemoryBase()};
    }
    return result;
  }

  // Check for writes to redzones.
  TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
                      CheckRedzones(*input_output_allocator, stream,
                                    "input/output", instr, &result));

  TF_ASSIGN_OR_RETURN(
      bool scratch_allocator_redzone_clear,
      CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));

  if (!input_output_allocator_redzone_clear ||
      !scratch_allocator_redzone_clear) {
    std::string canonical_hlo =
        std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));

    std::string blas_version;
    if (auto* blas = stream_exec_->AsBlas()) {
      (void)blas->GetVersion(&blas_version);
    }

    AlgorithmDenylist proto;
    auto entry = proto.add_entries();
    entry->set_hlo(canonical_hlo);
    *entry->mutable_cc() = GetComputeCapability(stream_exec_);
    *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
    entry->set_blas_version(blas_version);
    auto algo = entry->add_algos();
    algo->set_id(alg.algo_id());
    algo->set_tensor_ops(alg.tensor_ops_enabled());

    LOG(ERROR) << "To denylist this algorithm for this convolution, "
                  "copy-paste the following "
                  "proto to the denylist file pointed to by XLA_FLAGS "
                  "--xla_gpu_algorithm_denylist_path="
               << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
               << " : " << proto.ShortDebugString();

    // CheckRedzones has modified the result in-place to include a failure.
    return result;
  }

  if (reference_result->has_value()) {
    XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
    BufferComparator comparator(result_shape, hlo_module_config);
    StatusOr<bool> compare_result = comparator.CompareEqual(
        stream, (*reference_result)->buffer, result_buffer);
    if (!compare_result.ok()) {
      LOG(ERROR) << "Unable to compare "
                 << (*reference_result)->algorithm.ToString() << " against "
                 << alg.ToString() << " for " << instr->ToString() << ": "
                 << compare_result.status();
      if (compare_result.status().code() ==
          tensorflow::error::RESOURCE_EXHAUSTED) {
        // Possibly OOM. Propagate the error.
        return compare_result.status();
      }
      const DebugOptions& debug_options =
          instr->GetModule()->config().debug_options();
      CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
    } else if (!compare_result.ValueOrDie()) {
      LOG(ERROR)
          << "Results mismatch between different convolution algorithms. "
             "This is likely a bug/unexpected loss of precision in cudnn.\n"
          << instr->ToString() << " for "
          << (*reference_result)->algorithm.ToString() << " vs "
          << alg.ToString();
      PrintPlatformInfo(stream);
      VLOG(1) << "Full module on failure: \n" << instr->GetModule()->ToString();
      auto* fail = result.mutable_failure();
      fail->set_kind(AutotuneResult::WRONG_RESULT);
      fail->set_buffer_address(
          reinterpret_cast<uint64_t>(result_buffer.opaque()));
      *fail->mutable_reference_algorithm() =
          (*reference_result)->algorithm.ToProto();
    }
  } else {
    XLA_SCOPED_LOGGING_TIMER_LEVEL("Memcpy Reference Result", 2);
    TF_ASSIGN_OR_RETURN(
        auto reference_result_buffer,
        input_output_allocator->AllocateBytes(result_buffer.size()));
    stream->ThenMemcpy(&reference_result_buffer, result_buffer,
                       result_buffer.size());
    (*reference_result) = {alg, reference_result_buffer};
  }

  return result;
}

StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  // Right now the redzone allocator is available only for the CUDA target.
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));

  const Shape& result_shape = instr->shape().tuple_shapes(0);
  int64_t rng_state = 0;

  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const bool init_conv_data = ShouldInitConvData(instr);
  const auto initialize_buffer = [init_conv_data, &stream, &rng_state](
                                     DeviceMemoryBase buffer,
                                     const Shape& buffer_shape) {
    if (init_conv_data) {
      InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
    }
  };

  // Allocate space for the input, filter, and output of the convolution.
  const int64_t redzone_size =
      ShouldCheckConv(instr) ? se::RedzoneAllocator::kDefaultRedzoneSize : 0;
  se::RedzoneAllocator input_output_allocator(
      stream, allocator,
      PtxOptsFromDebugOptions(hlo_module_config.debug_options()),
      /*memory_limit=*/std::numeric_limits<int64_t>::max(),
      /*redzone_size=*/redzone_size);
  std::vector<se::DeviceMemoryBase> operand_buffers;
  for (const auto* operand : instr->operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer, operand->shape());
    operand_buffers.push_back(buffer);
  }
  TF_ASSIGN_OR_RETURN(auto result_buffer,
                      input_output_allocator.AllocateBytes(
                          ShapeUtil::ByteSizeOf(result_shape)));
  initialize_buffer(result_buffer, result_shape);

  const DebugOptions& debug_options =
      instr->GetModule()->config().debug_options();

  const bool crash_on_checking_failure =
      debug_options.xla_gpu_crash_on_verification_failures();

  std::string canonical_hlo =
      std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));

  std::string blas_version;
  if (auto* blas = stream_exec_->AsBlas()) {
    (void)blas->GetVersion(&blas_version);
  }

  absl::Span<const AlgorithmDesc> disabled_algos = GetDisabledConvAlgorithms(
      GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_),
      blas_version, canonical_hlo);

  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));

  const bool cudnn_frontend_enabled =
      debug_options.xla_gpu_enable_cudnn_frontend();

  // Use the first algorithm that works as the reference result. There is no
  // particular reason to prefer it; any algorithm would do. Being the
  // reference does not mean that algorithm is considered correct.
  std::optional<ReferenceResult> reference_result;

  TF_ASSIGN_OR_RETURN(std::vector<MaybeFusedConvRunner> runners,
                      GetAlgorithms(config, stream, cudnn_frontend_enabled,
                                    /* use_fallback = */ false));

  std::vector<AutotuneResult> profile_results;
  for (auto& runner_cache : runners) {
    TF_ASSIGN_OR_RETURN(
        auto result, AutotuneOneConvRunner(
                         config, instr, allocator, &input_output_allocator,
                         stream, &runner_cache, operand_buffers, result_buffer,
                         &reference_result, disabled_algos));
    profile_results.emplace_back(std::move(result));
  }

  // If any algorithm has worked, we'll skip the fallback algorithms, since
  // they include some very slow algorithms.
  if (!reference_result) {
    LOG(WARNING) << "None of the algorithms provided by cuDNN heuristics "
                    "worked; trying fallback algorithms.  Conv: "
                 << canonical_hlo;

    TF_ASSIGN_OR_RETURN(std::vector<MaybeFusedConvRunner> fallback_runners,
                        GetAlgorithms(config, stream, cudnn_frontend_enabled,
                                      /* use_fallback = */ true));

    for (auto& runner_cache : fallback_runners) {
      TF_ASSIGN_OR_RETURN(
          auto result, AutotuneOneConvRunner(
                           config, instr, allocator, &input_output_allocator,
                           stream, &runner_cache, operand_buffers,
                           result_buffer, &reference_result, disabled_algos));
      profile_results.emplace_back(std::move(result));
    }
  }

  // Log the autotuning result.
  {
    tensorflow::AutotuningLog log;
    {
      ConvInstructionLog instr_log;
      *instr_log.mutable_instruction() = instr->ToProto();
      for (int i = 0; i < instr->operand_count(); i++) {
        *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
        instr_log.add_operand_addresses(
            reinterpret_cast<uint64_t>(operand_buffers[i].opaque()));
      }
      instr_log.set_result_address(
          reinterpret_cast<uint64_t>(result_buffer.opaque()));
      log.mutable_instr()->PackFrom(instr_log);
    }
    for (const auto& profile : profile_results) {
      *log.add_results() = profile;
    }
    *log.mutable_compute_capability() = GetComputeCapability(stream_exec_);
    *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
    log.set_device_pci_bus_id(
        stream_exec_->GetDeviceDescription().pci_bus_id());
    log.set_blas_version(blas_version);
    VLOG(1) << "Autotuning result: " << log.ShortDebugString();
    // If we crash on checking failure, we are in a testing/benchmark mode, thus
    // omitting logging through the logger.
    if (!crash_on_checking_failure) {
      tensorflow::Logger::GetSingleton()->LogProto(log);
    } else {
      // Crash on miscompares and redzone violations if desired.
      for (const auto& profile : profile_results) {
        if (profile.has_failure() &&
            profile.failure().kind() != AutotuneResult::DISQUALIFIED) {
          LOG(FATAL) << "crash_on_checking_failure encountered errors:\n\n"
                     << log.DebugString();
        }
      }
    }
  }

  TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
                      PickBestResult(profile_results, *instr));
  return selected_algorithm;
}
#endif

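// ROCm autotuning path: profiles the MIOpen-provided runners directly (no
// redzone checking or cross-algorithm result verification, which are
// CUDA-only above) and picks the best result.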
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));

  const auto device_ordinal = stream_exec_->device_ordinal();
  std::vector<se::DeviceMemoryBase> operand_buffers;

  ScratchAllocator input_output_allocator(device_ordinal, allocator);
  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
    // Although we don't have evidence this matters, zero out the buffers
    // before autotuning.  It's conceivable that using uninitialized memory as
    // the inputs might affect performance if e.g. the inputs contain
    // denormals, and this is easy enough.
    stream->ThenMemZero(&buffer, buffer.size());
  };

  // Allocate space for the input, filter, and output of the convolution.  We
  // use a ScratchAllocator for this instead of calling allocator_ directly so
  // that our allocations don't leak.
  for (const auto* operand : instr->operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer);
    operand_buffers.push_back(buffer);
  }

  TF_ASSIGN_OR_RETURN(
      auto result_buffer,
      input_output_allocator.AllocateBytes(
          ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
  initialize_buffer(result_buffer);

  ScratchAllocator scratch_allocator(device_ordinal, allocator);

  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners,
      GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers), result_buffer,
                          stream_exec_, &scratch_allocator, stream));

  std::vector<AutotuneResult> profile_results;

  if (runners.size() == 1) {
    TF_ASSIGN_OR_RETURN(auto alg, runners[0]->ToAlgorithmDesc());
    auto algorithm_proto = alg.ToProto();
    profile_results.emplace_back();
    auto& result = profile_results.back();
    *result.mutable_algorithm() = algorithm_proto;

    result.set_scratch_bytes(runners[0]->GetWorkspaceSize());

    // TODO(awpr): if the profile result time for a singleton algorithm is
    // needed, plumb it via OpRunner; we'll need to do this to let TF ops avoid
    // re-profiling ROCm algorithms anyway.
    *result.mutable_run_time() =
        tensorflow::proto_utils::ToDurationProto(absl::Milliseconds(-1));
852     TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
853     for (auto& runner : runners) {
854       TF_ASSIGN_OR_RETURN(auto alg, runner->ToAlgorithmDesc());
855       XLA_SCOPED_LOGGING_TIMER_LEVEL(
856           absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
857                        alg.ToString()),
858           2);
859 
860       se::dnn::ProfileResult profile_result;
861       VLOG(3) << "Trying algorithm " << alg.ToString() << " for "
862               << instr->ToString();
863 
864       TF_ASSIGN_OR_RETURN(
865           DeviceMemoryBase scratch_memory,
866           scratch_allocator.AllocateBytes(runner->GetWorkspaceSize()));
867 
868       TF_ASSIGN_OR_RETURN(auto lazy_runner,
869                           se::dnn::LazyOpRunner<se::dnn::ConvOp>::FromOpRunner(
870                               std::move(runner)));
871 
872       MaybeFusedConvRunner runner_cache(std::move(lazy_runner));
873 
874       // Use assignment instead of brace-list to make GCC 4.9 happy.
875       RunConvOptions options;
876       options.profile_result = &profile_result;
877       options.runner_cache = &runner_cache;
878       Status launch_status =
879           RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer,
880                      scratch_memory, stream, options);
881 
882       if (!launch_status.ok()) {
883         continue;
884       }
885 
886       if (!profile_result.is_valid()) {
887         continue;
888       }
889 
890       profile_results.emplace_back();
891       AutotuneResult& result = profile_results.back();
892       *result.mutable_algorithm() = alg.ToProto();
893 
894       int64_t scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
895       result.set_scratch_bytes(scratch_bytes_used);
896       *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
897           absl::Milliseconds(profile_result.elapsed_time_in_ms()));
898     }
899   }
900 
901   TF_ASSIGN_OR_RETURN(AutotuneResult selected_algorithm,
902                       PickBestResult(profile_results, *instr));
903   return selected_algorithm;
904 }
905 
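// Autotunes a single conv custom-call and rewrites it into an equivalent
// custom-call that carries the chosen algorithm and a workspace of the
// required size.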
StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) {
  CHECK(IsCustomCallToDnnConvolution(*instr));

  const bool strict = instr->parent()
                          ->parent()
                          ->config()
                          .debug_options()
                          .xla_gpu_strict_conv_algorithm_picker();

  StatusOr<AutotuneResult> best_algo_or =
      PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
  if (!best_algo_or.ok()) {
    auto msg = absl::StrFormat(
        "Failed to determine best cudnn convolution algorithm for:\n%s\n\n"
        "Original error: %s",
        instr->ToString(), best_algo_or.status().ToString());

    if (strict) {
      return Unknown(
          "%s\n\nTo ignore this failure and try to use a fallback algorithm "
          "(which may have suboptimal performance), use "
          "XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please "
          "also file a bug for the root cause of failing autotuning.",
          msg);
    }
    LOG(WARNING)
        << msg << "\n\nAs a result, convolution performance may be suboptimal.";
    return false;
  }

  auto best_algo = std::move(best_algo_or).ValueOrDie();
  VLOG(2) << "Setting cudnn conv to use algorithm "
          << best_algo.conv().algorithm() << " and "
          << NumBytesToString(best_algo.scratch_bytes())
          << " of scratch memory: " << instr->ToString()
          << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();

  // Replace instr with a new CustomCall which has the correct algorithm, and
  // whose output shape has the appropriate amount of scratch memory.
  HloComputation* computation = instr->parent();
  Shape new_call_shape = ShapeUtil::MakeTupleShape(
      {instr->shape().tuple_shapes(0),
       ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})});

  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                      instr->backend_config<CudnnConvBackendConfig>());
  *backend_config.mutable_algorithm() = best_algo.algorithm();
  backend_config.mutable_algorithm()->mutable_workspace_size()->set_value(
      best_algo.scratch_bytes());

  HloInstruction* new_call = computation->AddInstruction(
      instr->CloneWithNewOperands(new_call_shape, instr->operands()));

  // Preserve the name of the old instruction.  This is safe because we're going
  // to remove the old one anyway, and it makes it easier to trace how our conv
  // is transformed through all our passes.
  new_call->SetAndSanitizeName(instr->name());

  VLOG(2) << "Replacing convolution " << instr->ToString() << " with "
          << new_call->ToString();

  TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));

  // Repackage new_call so it has the same shape as the original call, namely
  // (conv_result, u8[0]).
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(
          {computation->AddInstruction(HloInstruction::CreateGetTupleElement(
               new_call_shape.tuple_shapes(0), new_call, 0)),
           computation->AddInstruction(HloInstruction::CreateConstant(
               LiteralUtil::CreateR1<uint8_t>({})))}));

  TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
  return true;
}

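// Runs the picker over every conv custom-call in `computation`; returns true
// if any instruction was rewritten.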
StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
    HloComputation* computation) {
  std::vector<HloInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(instr);
    }
  }

  bool changed = false;
  for (auto* instr : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
    changed |= result;
  }
  return changed;
}

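// Pass entry point: skips everything if conv autotuning is disabled
// (xla_gpu_autotune_level == 0), otherwise processes all non-fusion
// computations and logs cache statistics at the end.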
StatusOr<bool> GpuConvAlgorithmPicker::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");

  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
    VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
               "returning early.";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }

  {
    absl::MutexLock lock(&autotune_cache_lock);
    autotune_cache_stats.LogStats();
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla