/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"

#include <stdint.h>

#include <algorithm>
#include <cstring>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/dynamic_annotations.h"
#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/host/host_stream.h"

namespace xla {
namespace cpu {

CpuExecutable::CpuExecutable(
    std::unique_ptr<SimpleOrcJIT> jit,
    std::unique_ptr<const BufferAssignment> assignment,
    std::unique_ptr<HloModule> hlo_module,
    const std::string& entry_function_name,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
    : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                 std::move(hlo_profile_index_map)),
      jit_(std::move(jit)),
      assignment_(std::move(assignment)),
      module_name_(entry_function_name) {
  if (assignment_) {
    buffer_assignment_.reset(
        new BufferAssignmentProto(assignment_->ToProto()));
  }
  if (has_module()) {
    XlaDebugInfoManager::Get()->RegisterModule(
        module().unique_id(), shared_module(), buffer_assignment_);
  }

  // Resolve symbols in the constructor rather than at execution time to avoid
  // races because FindSymbol is not thread safe.
  llvm::Expected<llvm::JITEvaluatedSymbol> sym =
      jit_->FindCompiledSymbol(entry_function_name);
  // We expect to find the symbol provided with entry_function_name; otherwise
  // this is an internal error.
  CHECK(*sym) << "Symbol " << entry_function_name << " not found.";
  // getAddress can do work under the hood in the jit, so it needs to be
  // guarded by the mutex.
  compute_function_ = reinterpret_cast<ComputeFunctionType>(sym->getAddress());
  VLOG(1) << "compute_function_ at address "
          << reinterpret_cast<void*>(compute_function_);
  jit_->DoneCompiling();
}

CpuExecutable::~CpuExecutable() {
  if (has_module()) {
    XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
  }
}

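// Returns the backing memory for `allocation`. Entry-computation parameters
// are serviced directly from the caller-provided arguments; constant and
// thread-local allocations need no runtime buffer (the JITed code materializes
// them itself), so they get a null placeholder; everything else is freshly
// allocated from `memory_allocator`.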
static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
    const BufferAllocation& allocation,
    absl::Span<ExecutionInput const> arguments,
    se::DeviceMemoryAllocator* memory_allocator, int device_ordinal) {
  VLOG(3) << allocation.ToString();
  if (allocation.is_entry_computation_parameter()) {
    se::DeviceMemoryBase out = arguments[allocation.parameter_number()]
                                   .Buffer(allocation.param_shape_index())
                                   .AsDeviceMemoryBase();
    CHECK_LE(allocation.size(), out.size())
        << "Size mismatch on param " << allocation.parameter_number()
        << " at shape index " << allocation.param_shape_index().ToString();
    VLOG(3) << "allocation is a parameter";
    return MaybeOwningDeviceMemory{out};
  } else if (allocation.is_constant()) {
    VLOG(3) << "allocation is a constant";
    return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
  } else if (allocation.is_thread_local()) {
    VLOG(3) << "buffer is thread-local";
    return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
  }

  int64_t buffer_size = allocation.size();
  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory out,
                      memory_allocator->Allocate(device_ordinal, buffer_size));
  VLOG(3) << "buffer allocated " << buffer_size << " bytes [" << out->opaque()
          << "]";

  // Since the output buffer and all the temporary buffers were written into
  // by the JITed code, msan has no way of knowing their memory was
  // initialized. Mark them initialized so that msan doesn't flag loads from
  // these buffers.
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(out->opaque(), buffer_size);
  return MaybeOwningDeviceMemory{std::move(out)};
}

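// Builds the buffer table consumed by the JITed entry function: one entry per
// BufferAllocation in the BufferAssignment, indexed by allocation index.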
StatusOr<std::vector<MaybeOwningDeviceMemory>> CpuExecutable::CreateBufferTable(
    se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
    absl::Span<ExecutionInput const> arguments) {
  std::vector<MaybeOwningDeviceMemory> buffers(
      assignment_->Allocations().size());
  VLOG(3) << "Allocating " << assignment_->Allocations().size()
          << " allocations for module " << module().name();
  for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
       ++i) {
    const BufferAllocation& allocation = assignment_->GetAllocation(i);
    TF_ASSIGN_OR_RETURN(
        buffers[i], MemoryForAllocation(allocation, arguments,
                                        memory_allocator, device_ordinal));
  }

  if (VLOG_IS_ON(3)) {
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment_->GetUniqueTopLevelOutputSlice());
    VLOG(3) << "result index: " << result_slice.index();
  }
  return std::move(buffers);
}

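// Invokes the JITed entry function on the given buffer table and converts a
// failed XlaCustomCallStatus, if any, into an error Status.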
Status CpuExecutable::ExecuteComputeFunction(
    const ExecutableRunOptions* run_options,
    absl::Span<MaybeOwningDeviceMemory const> buffers,
    HloExecutionProfile* hlo_execution_profile) {
  uint64_t start_micros = tensorflow::Env::Default()->NowMicros();

  size_t profile_counters_size =
      hlo_execution_profile ? hlo_execution_profile->profile_counters().size()
                            : 0;
  int64_t* profile_counters =
      hlo_execution_profile
          ? hlo_execution_profile->mutable_profile_counters()->data()
          : nullptr;

  // Call the computation function following the calling convention. See the
  // definition of 'ComputeFunctionType' for the details of the calling
  // convention of JITed functions.
  std::vector<void*> buffer_pointers;
  for (auto& buffer : buffers) {
    buffer_pointers.push_back(
        const_cast<void*>(buffer.AsDeviceMemoryBase().opaque()));
  }

  VLOG(3) << "Executing compute function:";
  VLOG(3) << absl::StrFormat("  Number of buffer table entries: %u",
                             buffer_pointers.size());
  auto ptr_printer = [](std::string* out, const void* p) {
    absl::StrAppend(out, absl::StrFormat("%p", p));
  };
  VLOG(3) << absl::StrFormat("  Buffer table: [%s]",
                             absl::StrJoin(buffer_pointers, ", ", ptr_printer));
  VLOG(3) << absl::StrFormat("  Number of profile counters: %u",
                             profile_counters_size);
  VLOG(3) << absl::StrFormat("  Profile counters: %p", profile_counters);

  XlaCustomCallStatus status;
  // For the entry computation (like all global computations), all inputs and
  // outputs are in the buffer table, and both the result pointer and args
  // array pointers are unused (so we set them to 'nullptr').
  compute_function_(nullptr, run_options, nullptr, buffer_pointers.data(),
                    &status, profile_counters);

  uint64_t end_micros = tensorflow::Env::Default()->NowMicros();

  if (run_options->execution_profile()) {
    const double nanoseconds = (end_micros - start_micros) * 1000.0;
    run_options->execution_profile()->set_compute_time_ns(
        std::max(nanoseconds, 1.0));
    // If hlo profiling was disabled then the cycle count is left empty.
    if (hlo_execution_profile) {
      run_options->execution_profile()->set_compute_cycle_count(
          hlo_execution_profile->total_cycles_executed(
              *module().entry_computation()));
    }
  }

  std::optional<absl::string_view> error_message =
      CustomCallStatusGetMessage(&status);
  if (error_message) {
    return InternalError("CustomCall failed: %s", *error_message);
  }

  return OkStatus();
}

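// Moves the result buffers out of the buffer table into an ExecutionOutput.
// Each leaf of the result shape is either (a) backed by a donated input
// buffer, which is reused in place, (b) copy-protected into a freshly
// allocated buffer when aliasing was requested but the input was not donated,
// or (c) handed over from the buffer table directly.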
StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
    const ServiceExecutableRunOptions* run_options,
    absl::Span<MaybeOwningDeviceMemory> buffers,
    absl::Span<ExecutionInput> arguments) {
  se::Stream* stream = run_options->stream();
  ExecutionOutput result(/*on_device_shape=*/result_shape(),
                         run_options->allocator(),
                         stream->parent()->device_ordinal());
  const HloInputOutputAliasConfig& input_output_alias =
      module().input_output_alias_config();
  HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
  const Shape& root_shape = root->shape();

  // Move se::OwningDeviceMemory values which contain the array(s) of the
  // result into the respective location in ScopedShapedBuffer which is
  // returned to the caller.
  for (auto& p : result.MutableResult()->buffers()) {
    const ShapeIndex& index = p.first;
    se::DeviceMemoryBase& result_buffer = p.second;
    const HloValueSet& sources = this->GetRootValueSet().element(index);
    // The points-to set is unambiguous, so it should be a singleton.
    CHECK_EQ(1, sources.values().size());
    const HloValue* value_source = sources.values()[0];
    HloInstruction* src = value_source->instruction();

    // The source for this result buffer can be a nested buffer such as
    // a tuple element.
    TF_ASSIGN_OR_RETURN(
        const BufferAllocation::Slice slice,
        this->assignment_->GetUniqueSlice(src, value_source->index()));
    const BufferAllocation::Index buffer_index = slice.index();

    // TODO(cheshire): duplication with other backends.
    std::optional<HloInputOutputAliasConfig::Alias> alias =
        input_output_alias.GetAliasedParameter(index);
    if (alias) {
      CHECK_LT(alias->parameter_number, arguments.size());
      ExecutionInput& input = arguments[alias->parameter_number];
      MaybeOwningDeviceMemory* maybe_owning_memory =
          input.MutableBuffer(alias->parameter_index);
      if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
        return InvalidArgument(
            "An input was configured to be must-alias at "
            "compile time but not donated at runtime: %s",
            alias->ToString());
      }
      if (std::optional<se::OwningDeviceMemory> owning =
              maybe_owning_memory->Release()) {
        // If the caller passes ownership of the device memory, reuse it as
        // the output buffer. It is up to the caller whether or not to donate
        // a buffer; the aliasing information describes which buffers may
        // alias, not buffers that must alias.
        se::DeviceMemoryBase argument_buffer = owning->Release();
        *maybe_owning_memory = argument_buffer;
        result_buffer = argument_buffer;
        // The caller is giving us the input buffer, but if the execute call
        // fails we should not release it, since it still contains valid data
        // (for example, it may be a parameter which the user wants us to
        // alias, in a gradient-update computation). So we record the index in
        // the ExecutionOutput's aliased-index vector; if the ExecutionOutput
        // is never committed, those indices are used to drop the
        // corresponding addresses from its own ScopedShapedBuffer result.
        result.AddAliasedIndex(index);
      } else {
        VLOG(3) << "Using copy-protection: aliasing is specified, but the "
                   "buffer is not donated; allocating a fresh buffer";
        int64_t allocation_size =
            ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(root_shape, index));
        TF_ASSIGN_OR_RETURN(
            se::OwningDeviceMemory allocated_buffer,
            run_options->allocator()->Allocate(
                stream->parent()->device_ordinal(), allocation_size));
        result_buffer = allocated_buffer.Release();
        MaybeOwningDeviceMemory& registered_buffer = buffers[buffer_index];
        CHECK_EQ(result_buffer.size(),
                 registered_buffer.AsDeviceMemoryBase().size());
        std::memcpy(/*dest=*/result_buffer.opaque(),
                    /*src=*/registered_buffer.AsDeviceMemoryBase().opaque(),
                    /*n=*/result_buffer.size());
        registered_buffer = result_buffer;
      }
    }

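    // The result buffer was not donated from an input; hand over the buffer
    // from the buffer table. If we own the underlying memory, transfer
    // ownership into the ExecutionOutput; otherwise record the index as
    // aliased so the memory is not freed twice.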
    if (result_buffer.is_null()) {
      MaybeOwningDeviceMemory& buffer = buffers[buffer_index];
      if (std::optional<se::OwningDeviceMemory> owned_buffer =
              buffer.Release()) {
        result_buffer = owned_buffer->Release();
        buffer = result_buffer;
      } else {
        result_buffer = buffer.AsDeviceMemoryBase();
        result.AddAliasedIndex(index);
      }
    }
  }
  return std::move(result);
}

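// Asynchronously executes the computation: the actual work is enqueued on the
// stream's underlying HostStream, so this call may return before the
// computation has finished.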
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    std::vector<ExecutionInput> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  if (GetRootValueSet().IsAmbiguous()) {
    return Unimplemented("Points-to set of root instruction is ambiguous");
  }

  if (hlo_module_) {
    const HloComputation* entry_comp = hlo_module_->entry_computation();
    CHECK_EQ(entry_comp->num_parameters(), arguments.size())
        << "Wrong number of arguments passed when running executable";
    for (int64_t i = 0; i < entry_comp->num_parameters(); ++i) {
      const Shape& expected_shape =
          entry_comp->parameter_instruction(i)->shape();
      const Shape& actual_shape = arguments[i].Buffers().shape();
      TF_RET_CHECK(
          ShapeUtil::DynamicShapeIsCompatible(actual_shape, expected_shape))
          << "Shape mismatch on argument " << i << ", "
          << expected_shape.ToString(/*print_layout=*/true) << " vs. "
          << actual_shape.ToString(/*print_layout=*/true);
    }
  }

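  // On the CPU backend the se::Stream is implemented by a
  // se::host::HostStream, which runs enqueued tasks in FIFO order on a host
  // thread.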
  auto* host_stream = dynamic_cast<se::host::HostStream*>(
      run_options->stream()->implementation());
  se::Stream* stream = run_options->stream();
  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
  TF_ASSIGN_OR_RETURN(
      std::vector<MaybeOwningDeviceMemory> buffers,
      CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
                        arguments));

  TF_ASSIGN_OR_RETURN(
      ExecutionOutput result,
      CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers),
                               absl::MakeSpan(arguments)));

  // Logically we want this lambda to capture `buffers` by move; however, the
  // functor ultimately needs to be wrapped in a std::function, and
  // std::function requires its callable to be copyable. Thus we perpetrate
  // the hack of capturing buffers "by shared pointer".
  //
  // We also need to change the types of some of the variables we capture:
  // run_options needs to change from a pointer to a value type, and arguments
  // needs to change from a Span into a vector. We use a struct instead
  // of a lambda to make this explicit.
  struct AsyncRunTask {
    CpuExecutable* executable;
    ServiceExecutableRunOptions run_options;
    std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
    HloExecutionProfile* hlo_execution_profile;

    Status operator()() {
      return executable->ExecuteComputeFunction(
          &run_options.run_options(), *task_buffers, hlo_execution_profile);
    }
  };
  host_stream->EnqueueTaskWithStatus(
      AsyncRunTask{this, *run_options,
                   std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
                       std::move(buffers)),
                   hlo_execution_profile});

  MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);
  return std::move(result);
}


/*static*/ int64_t CpuExecutable::ShapeSizeBytes(const Shape& shape) {
  // On the CPU, opaques are pointers.
  if (shape.IsOpaque()) {
    return sizeof(void*);
  }
  if (shape.is_static() || shape.IsTuple()) {
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  }
  // Each dynamic dimension size is represented as an S32.
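  // For example, a dynamic f32[<=4] array occupies 16 bytes of data plus
  // 4 bytes of metadata for its single dynamic dimension, 20 bytes in total.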
  int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size();
  return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size;
}

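// Returns the dataflow value set of the entry computation's root instruction;
// CreateResultShapedBuffer uses it to map each leaf of the result shape back
// to the buffer slice that produces it.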
const InstructionValueSet& CpuExecutable::GetRootValueSet() const {
  return assignment_->dataflow_analysis().GetInstructionValueSet(
      module().entry_computation()->root_instruction());
}

int64_t CpuExecutable::SizeOfGeneratedCodeInBytes() const {
  return jit_->SizeOfGeneratedCodeInBytes();
}

}  // namespace cpu
}  // namespace xla