xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
17 
18 #include <memory>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
23 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
30 #include "tensorflow/stream_executor/device_memory.h"
31 #include "tensorflow/stream_executor/kernel.h"
32 
33 namespace xla {
34 namespace gpu {
35 
KernelThunk(ThunkInfo thunk_info,absl::Span<const BufferAllocation * const> args,const std::string & kernel_name,const LaunchDimensions & launch_dimensions)36 KernelThunk::KernelThunk(ThunkInfo thunk_info,
37                          absl::Span<const BufferAllocation* const> args,
38                          const std::string& kernel_name,
39                          const LaunchDimensions& launch_dimensions)
40     : Thunk(Kind::kKernel, thunk_info),
41       args_(args.begin(), args.end()),
42       kernel_name_(kernel_name),
43       launch_dimensions_(launch_dimensions) {}
44 
ToStringExtra(int indent) const45 std::string KernelThunk::ToStringExtra(int indent) const {
46   return absl::StrFormat(", kernel = %s, launch dimensions = %s", kernel_name_,
47                          launch_dimensions_.ToString());
48 }
49 
Initialize(const GpuExecutable & executable,se::StreamExecutor * executor)50 Status KernelThunk::Initialize(const GpuExecutable& executable,
51                                se::StreamExecutor* executor) {
52   absl::MutexLock lock(&mutex_);
53 
54   // Load the kernel into the device if necessary.
55   //
56   // We could alternatively do this within ExecuteOnStream, but doing it here
57   // lets the time spent loading the kernel not count towards our execution
58   // profiles.
59   auto it = kernel_cache_.find(executor);
60   if (kernel_cache_.end() == it) {
61     TF_ASSIGN_OR_RETURN(
62         std::unique_ptr<se::KernelBase> kernel,
63         CreateKernel(kernel_name_, args_.size(), executable.text(),
64                      executable.binary(), executor));
65 
66     kernel_cache_.emplace(executor, std::move(kernel));
67   }
68 
69   return OkStatus();
70 }
71 
PrintBufferContents(se::Stream * stream,absl::Span<const se::DeviceMemoryBase> buffer_args)72 static void PrintBufferContents(
73     se::Stream* stream, absl::Span<const se::DeviceMemoryBase> buffer_args) {
74   int input_idx = 0;
75   for (const se::DeviceMemoryBase& buf : buffer_args) {
76     auto host_buffer = std::make_unique<char[]>(buf.size());
77     CHECK(stream->ThenMemcpy(host_buffer.get(), buf, buf.size()).ok());
78     CHECK(stream->BlockHostUntilDone().ok());
79 
80     std::string buffer_contents;
81     for (int i = 0; i < buf.size(); i++) {
82       absl::StrAppendFormat(&buffer_contents, "%x ",
83                             static_cast<unsigned>(host_buffer[i]));
84     }
85     VLOG(100) << "BUF(" << input_idx++ << ") = " << buffer_contents;
86   }
87 }
88 
ExecuteOnStream(const ExecuteParams & params)89 Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {
90   // Load the kernel.
91   se::StreamExecutor* executor = params.stream->parent();
92   LaunchDimensions launch_dimensions;
93   const se::KernelBase* kernel = nullptr;
94 
95   {
96     absl::MutexLock lock(&mutex_);
97     auto it = kernel_cache_.find(executor);
98     CHECK(it != kernel_cache_.end())
99         << "Initialize() not called for StreamExecutor " << executor;
100     launch_dimensions = launch_dimensions_;
101     kernel = it->second.get();
102   }
103 
104   VLOG(3) << "Launching " << kernel->name();
105   absl::InlinedVector<se::DeviceMemoryBase, 4> buffer_args;
106   for (const BufferAllocation* arg : args_) {
107     se::DeviceMemoryBase buf =
108         params.buffer_allocations->GetDeviceAddress(arg->index());
109     VLOG(3) << "  Arg: alloc #" << arg->index() << ": " << buf.opaque() << "  ("
110             << buf.size() << "B)";
111     buffer_args.push_back(buf);
112   }
113 
114   if (VLOG_IS_ON(100)) {
115     PrintBufferContents(params.stream, buffer_args);
116   }
117 
118   return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions,
119                                params.stream);
120 }
121 
122 }  // namespace gpu
123 }  // namespace xla
124