xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"
17 
18 #include <functional>
19 #include <limits>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23 
24 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
25 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
26 #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
27 #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
28 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/platform/logger.h"
34 #include "tensorflow/core/platform/statusor.h"
35 #include "tensorflow/core/protobuf/autotuning.pb.h"
36 #include "tensorflow/core/util/proto/proto_utils.h"
37 #include "tensorflow/stream_executor/blas.h"
38 #include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
39 #include "tensorflow/stream_executor/device_memory.h"
40 #include "tensorflow/stream_executor/device_memory_allocator.h"
41 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
42 
43 namespace xla {
44 namespace gpu {
45 
46 using tensorflow::AutotuneResult;
47 
48 namespace {
49 
// Autotuning knobs used by this file; populated by GetConfig() from
// DebugOptions.
//
// The amount of work done per benchmarked algorithm grows with
// `autotune_level`:
//   >= 2: input/output buffers are initialized before benchmarking.
//   >= 3: the output buffer may additionally be re-initialized before each run.
//   >= 4: results are additionally checked for correctness.
struct AutotuneConfig {
  bool should_init_buffers() const { return autotune_level >= 2; }
  bool should_reinit_output_buffer() const { return autotune_level >= 3; }
  bool should_check_correctness() const { return autotune_level >= 4; }

  // Default member initializers keep a default-constructed config in a
  // well-defined state (level 0, no crash) instead of holding indeterminate
  // values. The struct remains an aggregate, so `{level, crash}` still works.
  int32_t autotune_level = 0;
  bool should_crash_on_check_failure = false;
};
58 
GetConfig(const DebugOptions & debug_options)59 AutotuneConfig GetConfig(const DebugOptions& debug_options) {
60   return {debug_options.xla_gpu_autotune_level(),
61           debug_options.xla_gpu_crash_on_verification_failures()};
62 }
63 
CreateRedzoneAllocator(se::Stream * stream,se::DeviceMemoryAllocator * allocator,const DebugOptions & debug_options,const AutotuneConfig & config)64 se::RedzoneAllocator CreateRedzoneAllocator(
65     se::Stream* stream, se::DeviceMemoryAllocator* allocator,
66     const DebugOptions& debug_options, const AutotuneConfig& config) {
67   int64_t redzone_size = config.should_check_correctness()
68                              ? se::RedzoneAllocator::kDefaultRedzoneSize
69                              : 0;
70 
71   return se::RedzoneAllocator(
72       stream, allocator, PtxOptsFromDebugOptions(debug_options),
73       /*memory_limit=*/std::numeric_limits<int64_t>::max(),
74       /*redzone_size=*/redzone_size);
75 }
76 
CreateBuffer(se::RedzoneAllocator & allocator,const HloInstruction & op,const AutotuneConfig & config,int64_t & rng_state)77 StatusOr<se::DeviceMemoryBase> CreateBuffer(se::RedzoneAllocator& allocator,
78                                             const HloInstruction& op,
79                                             const AutotuneConfig& config,
80                                             int64_t& rng_state) {
81   TF_ASSIGN_OR_RETURN(
82       se::DeviceMemoryBase buffer,
83       allocator.AllocateBytes(ShapeUtil::ByteSizeOf(op.shape())));
84   if (config.should_init_buffers()) {
85     InitializeBuffer(allocator.stream(), op.shape().element_type(), &rng_state,
86                      buffer);
87   }
88   return buffer;
89 }
90 
91 // Returns the index (into `algorithms`) of the fastest algorithm.
92 template <typename AlgoT>
GetBestAlgorithm(se::Stream * stream,se::RedzoneAllocator & allocator,const HloInstruction & gemm,const AutotuneConfig & autotune_config,se::DeviceMemoryBase lhs_buffer,se::DeviceMemoryBase rhs_buffer,se::DeviceMemoryBase output_buffer,absl::Span<const AlgoT> algorithms,const std::function<StatusOr<se::blas::ProfileResult> (const AlgoT &)> & run_benchmark)93 StatusOr<std::optional<size_t>> GetBestAlgorithm(
94     se::Stream* stream, se::RedzoneAllocator& allocator,
95     const HloInstruction& gemm, const AutotuneConfig& autotune_config,
96     se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
97     se::DeviceMemoryBase output_buffer, absl::Span<const AlgoT> algorithms,
98     const std::function<StatusOr<se::blas::ProfileResult>(const AlgoT&)>&
99         run_benchmark) {
100   if (!stream->parent()->SynchronizeAllActivity()) {
101     return InternalError("Failed to synchronize GPU for autotuning.");
102   }
103 
104   TF_ASSIGN_OR_RETURN(GemmBackendConfig backend_config,
105                       gemm.backend_config<GemmBackendConfig>());
106 
107   se::DeviceMemoryBase reference_buffer;
108   if (autotune_config.should_check_correctness()) {
109     TF_ASSIGN_OR_RETURN(
110         reference_buffer,
111         allocator.AllocateBytes(ShapeUtil::ByteSizeOf(gemm.shape())));
112   }
113 
114   BufferComparator comparator(gemm.shape(), gemm.GetModule()->config());
115 
116   std::vector<AutotuneResult> results;
117   std::optional<int64_t> reference_algorithm;
118 
119   for (const AlgoT& algorithm : algorithms) {
120     // Make sure the output buffer always has the same value if we use
121     // the bias parameter.
122     if (autotune_config.should_reinit_output_buffer() &&
123         backend_config.beta() != 0) {
124       int64_t rng_state = 0;
125       InitializeBuffer(stream, gemm.shape().element_type(), &rng_state,
126                        output_buffer);
127     }
128 
129     TF_ASSIGN_OR_RETURN(se::blas::ProfileResult profile_result,
130                         run_benchmark(algorithm));
131 
132     results.emplace_back();
133     AutotuneResult& result = results.back();
134     result.mutable_gemm()->set_algorithm(profile_result.algorithm());
135 
136     if (!profile_result.is_valid()) {  // Unsupported algorithm.
137       result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED);
138       continue;
139     }
140 
141     VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took "
142             << profile_result.elapsed_time_in_ms() << "ms";
143 
144     *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
145         absl::Milliseconds(profile_result.elapsed_time_in_ms()));
146 
147     if (!autotune_config.should_check_correctness()) {
148       continue;
149     }
150 
151     TF_ASSIGN_OR_RETURN(
152         se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
153         allocator.CheckRedzones());
154 
155     if (!rz_check_status.ok()) {
156       result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
157       *result.mutable_failure()->mutable_msg() =
158           rz_check_status.RedzoneFailureMsg();
159       LOG(ERROR) << "Detected out-of-bounds write in gemm buffer";
160       CHECK(!autotune_config.should_crash_on_check_failure);
161       continue;
162     }
163 
164     if (!reference_algorithm) {
165       stream->ThenMemcpy(&reference_buffer, output_buffer,
166                          output_buffer.size());
167       reference_algorithm = profile_result.algorithm();
168     } else {
169       // Perform the comparison.
170       TF_ASSIGN_OR_RETURN(
171           bool outputs_match,
172           comparator.CompareEqual(stream, output_buffer, reference_buffer));
173       if (!outputs_match) {
174         LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
175                    << "This is likely a bug/unexpected loss of precision.";
176         CHECK(!autotune_config.should_crash_on_check_failure);
177 
178         result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT);
179         result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
180             *reference_algorithm);
181       }
182     }
183   }
184 
185   if (!autotune_config.should_crash_on_check_failure) {
186     tensorflow::AutotuningLog log;
187     for (const AutotuneResult& result : results) {
188       *log.add_results() = result;
189     }
190     tensorflow::Logger::GetSingleton()->LogProto(log);
191   }
192 
193   StatusOr<AutotuneResult> best = PickBestResult(results, gemm);
194   if (best.ok()) {
195     for (size_t i = 0; i < results.size(); ++i) {
196       if (best->gemm().algorithm() == results[i].gemm().algorithm()) {
197         return {i};
198       }
199     }
200     return InternalError("unknown best algorithm");
201   }
202 
203   LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance "
204                   "might be suboptimal: "
205                << best.status();
206   return {std::nullopt};
207 }
208 
DoGemmAutotune(const HloInstruction * gemm,const GemmBackendConfig & gemm_config,se::DeviceMemoryAllocator * allocator,se::Stream * stream)209 StatusOr<std::optional<se::blas::AlgorithmType>> DoGemmAutotune(
210     const HloInstruction* gemm, const GemmBackendConfig& gemm_config,
211     se::DeviceMemoryAllocator* allocator, se::Stream* stream) {
212   VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
213   const HloInstruction* lhs = gemm->operand(0);
214   const HloInstruction* rhs = gemm->operand(1);
215 
216   TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm));
217   // Don't run autotuning concurrently on the same GPU.
218   absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent()));
219 
220   auto key = std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
221                              gemm->shape(), gemm_config.SerializeAsString(),
222                              IsCublasLtMatmul(*gemm));
223 
224   static absl::Mutex mutex(absl::kConstInit);
225   static auto& cache ABSL_GUARDED_BY(mutex) =
226       *new absl::flat_hash_map<decltype(key),
227                                std::optional<se::blas::AlgorithmType>>();
228   static int64_t cache_hits ABSL_GUARDED_BY(mutex) = 0;
229   static int64_t cache_misses ABSL_GUARDED_BY(mutex) = 0;
230 
231   absl::MutexLock lock(&mutex);
232   auto it = cache.find(key);
233   int64_t requests = cache_hits + cache_misses;
234   if (requests && requests % 10 == 0) {
235     VLOG(2) << "Autotuning cache hits/(hits + misses): " << cache_hits << "/"
236             << requests;
237   }
238 
239   if (it != cache.end()) {
240     cache_hits++;
241     VLOG(4) << "Autotuning cache hit, using algorithm: "
242             << (it->second.has_value() ? absl::StrCat(*(it->second))
243                                        : "<generic>");
244     return it->second;
245   }
246   cache_misses++;
247   VLOG(4) << "Autotuning cache miss";
248 
249   const DebugOptions& debug_options =
250       gemm->GetModule()->config().debug_options();
251   AutotuneConfig autotune_config = GetConfig(debug_options);
252 
253   se::RedzoneAllocator buffer_allocator =
254       CreateRedzoneAllocator(stream, allocator, debug_options, autotune_config);
255 
256   int64_t rng_state = 0;
257   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
258                       CreateBuffer(buffer_allocator, *gemm->operand(0),
259                                    autotune_config, rng_state));
260   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
261                       CreateBuffer(buffer_allocator, *gemm->operand(1),
262                                    autotune_config, rng_state));
263   TF_ASSIGN_OR_RETURN(
264       se::DeviceMemoryBase output_buffer,
265       CreateBuffer(buffer_allocator, *gemm, autotune_config, rng_state));
266 
267   std::optional<se::blas::AlgorithmType> best_algorithm;
268   if (IsCublasLtMatmul(*gemm)) {
269     bool has_matrix_bias = config.beta != 0.;
270     bool has_vector_bias = gemm_config.epilogue() == GemmBackendConfig::BIAS;
271 
272     auto epilogue = has_vector_bias ? se::cuda::BlasLt::Epilogue::kBias
273                                     : se::cuda::BlasLt::Epilogue::kDefault;
274 
275     se::DeviceMemoryBase bias_buffer;
276     if (has_vector_bias) {
277       TF_ASSIGN_OR_RETURN(bias_buffer,
278                           CreateBuffer(buffer_allocator,
279                                        *gemm->operand(has_matrix_bias ? 3 : 2),
280                                        autotune_config, rng_state));
281     }
282 
283     TF_ASSIGN_OR_RETURN(auto plan,
284                         cublas_lt::MatmulPlan::From(config, epilogue));
285     TF_ASSIGN_OR_RETURN(
286         std::vector<se::cuda::BlasLt::MatmulAlgorithm> algorithms,
287         plan.GetAlgorithms(stream));
288 
289     TF_ASSIGN_OR_RETURN(
290         std::optional<size_t> best_algorithm_idx,
291         GetBestAlgorithm<se::cuda::BlasLt::MatmulAlgorithm>(
292             stream, buffer_allocator, *gemm, autotune_config, lhs_buffer,
293             rhs_buffer, output_buffer, algorithms,
294             [&](const se::cuda::BlasLt::MatmulAlgorithm& algorithm)
295                 -> StatusOr<se::blas::ProfileResult> {
296               se::OwningScratchAllocator<> scratch_allocator(
297                   stream->parent()->device_ordinal(), allocator);
298               se::blas::ProfileResult profile_result;
299               TF_RETURN_IF_ERROR(plan.ExecuteOnStream(
300                   stream, lhs_buffer, rhs_buffer, output_buffer, output_buffer,
301                   bias_buffer, algorithm, scratch_allocator, &profile_result));
302               return std::move(profile_result);
303             }));
304 
305     TF_RET_CHECK(best_algorithm_idx) << "failed to auto-tune cublas_lt matmul";
306     best_algorithm = *best_algorithm_idx;
307   } else {
308     std::vector<se::blas::AlgorithmType> algorithms;
309     TF_RET_CHECK(stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms));
310 
311     TF_ASSIGN_OR_RETURN(std::optional<size_t> best_algorithm_idx,
312                         GetBestAlgorithm<se::blas::AlgorithmType>(
313                             stream, buffer_allocator, *gemm, autotune_config,
314                             lhs_buffer, rhs_buffer, output_buffer, algorithms,
315                             [&](const se::blas::AlgorithmType& algorithm)
316                                 -> StatusOr<se::blas::ProfileResult> {
317                               se::blas::ProfileResult profile_result;
318                               // We expect GemmWithAlgorithm to fail sometimes
319                               // -- in fact, it will fail for all algorithms if
320                               // we're targeting < sm_50.  But because we pass a
321                               // non-null ProfileResult, DoGemmWithAlgorithm
322                               // should always return true, and the actual
323                               // success-ness is returned in
324                               // ProfileResult::is_valid.
325                               TF_RETURN_IF_ERROR(RunGemm(
326                                   config, lhs_buffer, rhs_buffer, output_buffer,
327                                   stream, algorithm, &profile_result));
328                               return std::move(profile_result);
329                             }));
330 
331     if (best_algorithm_idx) best_algorithm = algorithms[*best_algorithm_idx];
332   }
333 
334   CHECK(cache.emplace(key, best_algorithm).second);
335   return best_algorithm;
336 }
337 
RunOnInstruction(HloInstruction * instr,se::StreamExecutor * executor,se::DeviceMemoryAllocator * allocator)338 StatusOr<bool> RunOnInstruction(HloInstruction* instr,
339                                 se::StreamExecutor* executor,
340                                 se::DeviceMemoryAllocator* allocator) {
341   if (allocator == nullptr) {
342     allocator = executor->GetAllocator();
343   }
344   TF_ASSIGN_OR_RETURN(se::Stream* const stream,
345                       allocator->GetStream(executor->device_ordinal()));
346 
347   GemmBackendConfig gemm_config =
348       instr->backend_config<GemmBackendConfig>().ValueOrDie();
349 
350   TF_ASSIGN_OR_RETURN(std::optional<se::blas::AlgorithmType> gemm_algorithm,
351                       DoGemmAutotune(instr, gemm_config, allocator, stream));
352 
353   // We update instruction->backend_config(); if no algorithms are supported,
354   // a different API is used, which does not require specifying an algorithm.
355   GemmBackendConfig updated_config = gemm_config;
356   if (gemm_algorithm) {
357     VLOG(4) << "GEMM autotuning picked algorithm " << *gemm_algorithm << " for "
358             << instr->name();
359     updated_config.set_selected_algorithm(*gemm_algorithm);
360   }
361   TF_RETURN_IF_ERROR(instr->set_backend_config(updated_config));
362   return updated_config.SerializeAsString() != gemm_config.SerializeAsString();
363 }
364 
RunOnComputation(HloComputation * computation,se::StreamExecutor * se,se::DeviceMemoryAllocator * allocator)365 StatusOr<bool> RunOnComputation(HloComputation* computation,
366                                 se::StreamExecutor* se,
367                                 se::DeviceMemoryAllocator* allocator) {
368   bool changed = false;
369   for (HloInstruction* instr : computation->instructions()) {
370     if (IsCublasGemm(*instr) || IsCublasLtMatmul(*instr)) {
371       TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, se, allocator));
372       changed |= result;
373     }
374   }
375   return changed;
376 }
377 
378 }  // namespace
379 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)380 StatusOr<bool> GemmAlgorithmPicker::Run(
381     HloModule* module,
382     const absl::flat_hash_set<absl::string_view>& execution_threads) {
383   XLA_SCOPED_LOGGING_TIMER("GemmAlgorithmPicker");
384 
385   if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
386     VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
387     return false;
388   }
389 
390   bool changed = false;
391   for (HloComputation* computation :
392        module->MakeNonfusionComputations(execution_threads)) {
393     TF_ASSIGN_OR_RETURN(
394         bool result, RunOnComputation(computation, stream_exec_, allocator_));
395     changed |= result;
396   }
397   return changed;
398 }
399 
400 }  // namespace gpu
401 }  // namespace xla
402