/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"

#include <stdint.h>

#include <algorithm>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/host/host_stream.h"

namespace xla {
namespace cpu {

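// CpuExecutable wraps a module JIT-compiled by SimpleOrcJIT: the constructor
// resolves the entry function symbol once, and at run time a buffer table is
// built from the BufferAssignment, the compiled function is invoked on the
// host stream, and the resulting buffers are packaged into an ExecutionOutput.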
CpuExecutable::CpuExecutable(
    std::unique_ptr<SimpleOrcJIT> jit,
    std::unique_ptr<const BufferAssignment> assignment,
    std::unique_ptr<HloModule> hlo_module,
    const std::string& entry_function_name,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
    : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                 std::move(hlo_profile_index_map)),
      jit_(std::move(jit)),
      assignment_(std::move(assignment)),
      module_name_(entry_function_name) {
  if (assignment_) {
    buffer_assignment_.reset(new BufferAssignmentProto(assignment_->ToProto()));
  }
  if (has_module()) {
    XlaDebugInfoManager::Get()->RegisterModule(
        module().unique_id(), shared_module(), buffer_assignment_);
  }

  // Resolve symbols in the constructor rather than at execution time to avoid
  // races because FindSymbol is not thread safe.
  llvm::Expected<llvm::JITEvaluatedSymbol> sym =
      jit_->FindCompiledSymbol(entry_function_name);
  // We expect to find the symbol provided with entry_function_name; otherwise
  // this is an internal error.
  CHECK(*sym) << "Symbol " << entry_function_name << " not found.";
  // getAddress can do work under the hood in the jit, so it needs to be
  // guarded by the mutex.
  compute_function_ = reinterpret_cast<ComputeFunctionType>(sym->getAddress());
  VLOG(1) << "compute_function_ at address "
          << reinterpret_cast<void*>(compute_function_);
  jit_->DoneCompiling();
}

CpuExecutable::~CpuExecutable() {
  if (has_module()) {
    XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id());
  }
}

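// Resolves the backing memory for one BufferAllocation in the buffer table:
// entry-computation parameters reuse the caller-provided argument buffers,
// constants and thread-local allocations are given an empty DeviceMemoryBase,
// and all remaining allocations (temporaries and the output) are freshly
// allocated from `memory_allocator`.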
static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
    const BufferAllocation& allocation,
    absl::Span<ExecutionInput const> arguments,
    se::DeviceMemoryAllocator* memory_allocator, int device_ordinal) {
  VLOG(3) << allocation.ToString();
  if (allocation.is_entry_computation_parameter()) {
    se::DeviceMemoryBase out = arguments[allocation.parameter_number()]
                                   .Buffer(allocation.param_shape_index())
                                   .AsDeviceMemoryBase();
    CHECK_LE(allocation.size(), out.size())
        << "Size mismatch on param " << allocation.parameter_number()
        << " at shape index " << allocation.param_shape_index().ToString();
    VLOG(3) << "allocation is a parameter";
    return MaybeOwningDeviceMemory{out};
  } else if (allocation.is_constant()) {
    VLOG(3) << "allocation is a constant";
    return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
  } else if (allocation.is_thread_local()) {
    VLOG(3) << "buffer is thread-local";
    return MaybeOwningDeviceMemory{se::DeviceMemoryBase{}};
  }

  int64_t buffer_size = allocation.size();
  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory out,
                      memory_allocator->Allocate(device_ordinal, buffer_size));
  VLOG(3) << "buffer allocated " << buffer_size << " bytes [" << out->opaque()
          << "]";

  // Since the output buffer and all the temporary buffers were written into
  // by the JITed code, msan has no way of knowing their memory was
  // initialized. Mark them initialized so that msan doesn't flag loads from
  // these buffers.
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(out->opaque(), buffer_size);
  return MaybeOwningDeviceMemory{std::move(out)};
}

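// Builds the flat buffer table passed to the JITed entry function: one entry
// per BufferAllocation, indexed by BufferAllocation::Index and resolved via
// MemoryForAllocation above.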
StatusOr<std::vector<MaybeOwningDeviceMemory>> CpuExecutable::CreateBufferTable(
    se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
    absl::Span<ExecutionInput const> arguments) {
  std::vector<MaybeOwningDeviceMemory> buffers(
      assignment_->Allocations().size());
  VLOG(3) << "Allocating " << assignment_->Allocations().size()
          << " allocations for module " << module().name();
  for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
       ++i) {
    const BufferAllocation& allocation = assignment_->GetAllocation(i);
    TF_ASSIGN_OR_RETURN(
        buffers[i], MemoryForAllocation(allocation, arguments, memory_allocator,
                                        device_ordinal));
  }

  if (VLOG_IS_ON(3)) {
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment_->GetUniqueTopLevelOutputSlice());
    VLOG(3) << "result index: " << result_slice.index();
  }
  return std::move(buffers);
}

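// ExecuteComputeFunction invokes the JITed entry point. The authoritative
// typedef of ComputeFunctionType lives in cpu_executable.h; as a sketch (the
// parameter names below are illustrative, inferred from the call site), it
// has the shape:
//
//   void entry(void* result, const ExecutableRunOptions* run_options,
//              const void** args, void** buffer_table,
//              XlaCustomCallStatus* status, int64_t* profile_counters);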
Status CpuExecutable::ExecuteComputeFunction(
    const ExecutableRunOptions* run_options,
    absl::Span<MaybeOwningDeviceMemory const> buffers,
    HloExecutionProfile* hlo_execution_profile) {
  uint64_t start_micros = tensorflow::Env::Default()->NowMicros();

  size_t profile_counters_size =
      hlo_execution_profile ? hlo_execution_profile->profile_counters().size()
                            : 0;
  int64_t* profile_counters =
      hlo_execution_profile
          ? hlo_execution_profile->mutable_profile_counters()->data()
          : nullptr;

  // Call the computation function following the calling convention. See the
  // definition of 'ComputeFunctionType' for the details of the calling
  // convention of JITed functions.
  std::vector<void*> buffer_pointers;
  for (auto& buffer : buffers) {
    buffer_pointers.push_back(
        const_cast<void*>(buffer.AsDeviceMemoryBase().opaque()));
  }

  VLOG(3) << "Executing compute function:";
  VLOG(3) << absl::StrFormat("  Number of buffer table entries: %u",
                             buffer_pointers.size());
  auto ptr_printer = [](std::string* out, const void* p) {
    absl::StrAppend(out, absl::StrFormat("%p", p));
  };
  VLOG(3) << absl::StrFormat("  Buffer table: [%s]",
                             absl::StrJoin(buffer_pointers, ", ", ptr_printer));
  VLOG(3) << absl::StrFormat("  Number of profile counters: %u",
                             profile_counters_size);
  VLOG(3) << absl::StrFormat("  Profile counters: %p", profile_counters);

  XlaCustomCallStatus status;
  // For the entry computation (like all global computations), all inputs and
  // outputs are in the buffer table, and both the result pointer and args array
  // pointers are unused (so we set them to 'nullptr').
  compute_function_(nullptr, run_options, nullptr, buffer_pointers.data(),
                    &status, profile_counters);

  uint64_t end_micros = tensorflow::Env::Default()->NowMicros();

  if (run_options->execution_profile()) {
    const double nanoseconds = (end_micros - start_micros) * 1000.0;
    run_options->execution_profile()->set_compute_time_ns(
        std::max(nanoseconds, 1.0));
    // If hlo profiling was disabled then the cycle count is left empty.
    if (hlo_execution_profile) {
      run_options->execution_profile()->set_compute_cycle_count(
          hlo_execution_profile->total_cycles_executed(
              *module().entry_computation()));
    }
  }

  std::optional<absl::string_view> error_message =
      CustomCallStatusGetMessage(&status);
  if (error_message) {
    return InternalError("CustomCall failed: %s", *error_message);
  }

  return OkStatus();
}

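// Assembles the ExecutionOutput handed back to the caller. For each leaf
// buffer of the result shape one of three things happens:
//  * the output aliases a donated input, whose memory is reused in place;
//  * aliasing was requested but the input was not donated, so the result is
//    copied into a fresh buffer ("copy-protection");
//  * otherwise ownership of the corresponding buffer-table entry is moved
//    into (or the entry is aliased by) the ExecutionOutput.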
StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
    const ServiceExecutableRunOptions* run_options,
    absl::Span<MaybeOwningDeviceMemory> buffers,
    absl::Span<ExecutionInput> arguments) {
  se::Stream* stream = run_options->stream();
  ExecutionOutput result(/*on_device_shape=*/result_shape(),
                         run_options->allocator(),
                         stream->parent()->device_ordinal());
  const HloInputOutputAliasConfig& input_output_alias =
      module().input_output_alias_config();
  HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
  const Shape& root_shape = root->shape();

  // Move the se::OwningDeviceMemory values which contain the array(s) of the
  // result into the respective locations in the ScopedShapedBuffer which is
  // returned to the caller.
  for (auto& p : result.MutableResult()->buffers()) {
    const ShapeIndex& index = p.first;
    se::DeviceMemoryBase& result_buffer = p.second;
    const HloValueSet& sources = this->GetRootValueSet().element(index);
    // The points-to set is unambiguous, so the value set should be a
    // singleton.
    CHECK_EQ(1, sources.values().size());
    const HloValue* value_source = sources.values()[0];
    HloInstruction* src = value_source->instruction();

    // The source for this result buffer can be a nested buffer such as
    // a tuple element.
    TF_ASSIGN_OR_RETURN(
        const BufferAllocation::Slice slice,
        this->assignment_->GetUniqueSlice(src, value_source->index()));
    const BufferAllocation::Index buffer_index = slice.index();

    // TODO(cheshire): duplication with other backends.
    std::optional<HloInputOutputAliasConfig::Alias> alias =
        input_output_alias.GetAliasedParameter(index);
    if (alias) {
      CHECK_LT(alias->parameter_number, arguments.size());
      ExecutionInput& input = arguments[alias->parameter_number];
      MaybeOwningDeviceMemory* maybe_owning_memory =
          input.MutableBuffer(alias->parameter_index);
      if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
        return InvalidArgument(
            "An input was configured to be must-alias at "
            "compile time but not donated at runtime: %s",
            alias->ToString());
      }
      if (std::optional<se::OwningDeviceMemory> owning =
              maybe_owning_memory->Release()) {
        // If the caller passes the ownership of the device memory, reuse it
        // as the output buffer. It is up to the caller whether or not to
        // donate a buffer; the aliasing information describes which buffers
        // may alias, not buffers that must alias.
        se::DeviceMemoryBase argument_buffer = owning->Release();
        *maybe_owning_memory = argument_buffer;
        result_buffer = argument_buffer;
        // The caller is giving us the input buffer, but if the execute call
        // fails we should not release it, as it still contains valid data
        // (for example, it may be a parameter which the user wants us to
        // alias, in a gradient-update computation). So we store the index
        // into the result in the aliased vector, which is fed to the
        // ExecutionOutput; if the ExecutionOutput is not committed, it uses
        // those indices to drop the addresses from its own ScopedShapedBuffer
        // result.
        result.AddAliasedIndex(index);
      } else {
        VLOG(3) << "Using copy-protection: aliasing is specified, but the "
                   "buffer is not donated; allocating a fresh buffer";
        int64_t allocation_size =
            ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(root_shape, index));
        TF_ASSIGN_OR_RETURN(
            se::OwningDeviceMemory allocated_buffer,
            run_options->allocator()->Allocate(
                stream->parent()->device_ordinal(), allocation_size));
        result_buffer = allocated_buffer.Release();
        MaybeOwningDeviceMemory& registered_buffer = buffers[buffer_index];
        CHECK_EQ(result_buffer.size(),
                 registered_buffer.AsDeviceMemoryBase().size());
        std::memcpy(/*dest=*/result_buffer.opaque(),
                    /*src=*/registered_buffer.AsDeviceMemoryBase().opaque(),
                    /*n=*/result_buffer.size());
        registered_buffer = result_buffer;
      }
    }

    if (result_buffer.is_null()) {
      MaybeOwningDeviceMemory& buffer = buffers[buffer_index];
      if (std::optional<se::OwningDeviceMemory> owned_buffer =
              buffer.Release()) {
        result_buffer = owned_buffer->Release();
        buffer = result_buffer;
      } else {
        result_buffer = buffer.AsDeviceMemoryBase();
        result.AddAliasedIndex(index);
      }
    }
  }
  return std::move(result);
}

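// Top-level execution path: validates argument shapes against the entry
// computation, builds the buffer table, assembles the ExecutionOutput, and
// then enqueues the actual compute on the se::host::HostStream, so this
// method can return before the computation necessarily finishes.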
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    std::vector<ExecutionInput> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  if (GetRootValueSet().IsAmbiguous()) {
    return Unimplemented("Points-to set of root instruction is ambiguous");
  }

  if (hlo_module_) {
    const HloComputation* entry_comp = hlo_module_->entry_computation();
    CHECK_EQ(entry_comp->num_parameters(), arguments.size())
        << "Wrong number of arguments passed when running executable";
    for (int64_t i = 0; i < entry_comp->num_parameters(); ++i) {
      const Shape& expected_shape =
          entry_comp->parameter_instruction(i)->shape();
      const Shape& actual_shape = arguments[i].Buffers().shape();
      TF_RET_CHECK(
          ShapeUtil::DynamicShapeIsCompatible(actual_shape, expected_shape))
          << "Shape mismatch on argument " << i << ", "
          << expected_shape.ToString(/*print_layout=*/true) << " vs. "
          << actual_shape.ToString(/*print_layout=*/true);
    }
  }

  auto* host_stream = dynamic_cast<se::host::HostStream*>(
      run_options->stream()->implementation());
  se::Stream* stream = run_options->stream();
  se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
  TF_ASSIGN_OR_RETURN(
      std::vector<MaybeOwningDeviceMemory> buffers,
      CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
                        arguments));

  TF_ASSIGN_OR_RETURN(
      ExecutionOutput result,
      CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers),
                               absl::MakeSpan(arguments)));

  // Logically we want this lambda to capture `buffers` by move, but ultimately
  // our functor needs to be wrapped in a std::function, and that requires the
  // functor to be copyable. Thus we perpetrate the hack of capturing buffers
  // "by shared pointer".
  //
  // We also need to change the types of some of the variables we capture:
  // run_options needs to change from a pointer to a value type, and arguments
  // needs to change from a Span into a vector. We use a struct instead of a
  // lambda to make this explicit.
  struct AsyncRunTask {
    CpuExecutable* executable;
    ServiceExecutableRunOptions run_options;
    std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
    HloExecutionProfile* hlo_execution_profile;

    Status operator()() {
      return executable->ExecuteComputeFunction(
          &run_options.run_options(), *task_buffers, hlo_execution_profile);
    }
  };
  host_stream->EnqueueTaskWithStatus(
      AsyncRunTask{this, *run_options,
                   std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
                       std::move(buffers)),
                   hlo_execution_profile});

  MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);
  return std::move(result);
}

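// Computes the on-device size of `shape` for the CPU backend. Illustrative
// example (ours, not from the original source): a dynamic shape f32[<=8]
// takes 8 * 4 = 32 bytes for the data (sized by its bound) plus
// 1 * sizeof(int32_t) = 4 bytes of dynamic-dimension metadata, 36 in total.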
/*static*/ int64_t CpuExecutable::ShapeSizeBytes(const Shape& shape) {
  // On the cpu, opaques are pointers.
  if (shape.IsOpaque()) {
    return sizeof(void*);
  }
  if (shape.is_static() || shape.IsTuple()) {
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  }
  // Each dynamic dimension size is represented as an S32.
  int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size();
  return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size;
}

const InstructionValueSet& CpuExecutable::GetRootValueSet() const {
  return assignment_->dataflow_analysis().GetInstructionValueSet(
      module().entry_computation()->root_instruction());
}

int64_t CpuExecutable::SizeOfGeneratedCodeInBytes() const {
  return jit_->SizeOfGeneratedCodeInBytes();
}

}  // namespace cpu
}  // namespace xla