/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/util.h"

#define EIGEN_USE_THREADS

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_xfeed.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime

namespace xla {
namespace {

// A RAII helper class used to set an AsyncValueRef<CpuEvent> to a ready state
// upon destruction. In many places in the PjRt implementation a function has
// multiple return statements, each of which requires setting some
// AsyncValueRef<CpuEvent> to be ready. This class makes such code more robust
// by setting the AsyncValue in the destructor.
class MarkEventReadyOnExit {
 public:
  explicit MarkEventReadyOnExit(tfrt::AsyncValueRef<CpuEvent> event)
      : event_(std::move(event)) {}

  MarkEventReadyOnExit(const MarkEventReadyOnExit&) = delete;
  MarkEventReadyOnExit& operator=(const MarkEventReadyOnExit&) = delete;
  MarkEventReadyOnExit(MarkEventReadyOnExit&&) = default;
  MarkEventReadyOnExit& operator=(MarkEventReadyOnExit&&) = default;

  ~MarkEventReadyOnExit() {
    if (event_) event_.SetStateConcrete();
  }

  tfrt::AsyncValueRef<CpuEvent> Release() && { return std::move(event_); }

 private:
  tfrt::AsyncValueRef<CpuEvent> event_;
};
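
// Illustrative usage (a sketch, not code from this file; `StepThatMayFail` is
// a hypothetical helper): any early return below still marks `done_event`
// ready via the destructor.
//
//   Status Example(tfrt::AsyncValueRef<CpuEvent> done_event) {
//     MarkEventReadyOnExit ready_on_exit(std::move(done_event));
//     TF_RETURN_IF_ERROR(StepThatMayFail());
//     return OkStatus();
//   }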

}  // namespace

static const char kCpuPlatformName[] = "cpu";
static constexpr size_t kSmallDataTransferByteSize = 102400;  // 100 KiB

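// Returns a reference to a process-wide CpuEvent that is already in the
// available state. The underlying AsyncValueRef is created on the first call
// (with that call's HostContext) and kept alive in a function-local static for
// the remainder of the process.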
static tfrt::AsyncValueRef<CpuEvent> GetOrCreateReadyEvent(
    tfrt::HostContext* host_context) {
  static const auto* ready_event = new tfrt::AsyncValueRef<CpuEvent>(
      tfrt::MakeAvailableAsyncValueRef<CpuEvent>(host_context));
  return ready_event->CopyRef();
}

TfrtCpuDevice::TfrtCpuDevice(int id, bool asynchronous)
    : id_(id),
      max_inflight_computations_semaphore_(/*capacity=*/asynchronous ? 32 : 1) {
  debug_string_ = absl::StrCat("TFRT_CPU_", id);
  to_string_ = absl::StrCat("CpuDevice(id=", id, ")");
}

absl::string_view TfrtCpuDevice::device_kind() const {
  return kCpuPlatformName;
}

absl::string_view TfrtCpuDevice::DebugString() const { return debug_string_; }

absl::string_view TfrtCpuDevice::ToString() const { return to_string_; }

Status TfrtCpuDevice::TransferToInfeed(const LiteralSlice& literal) {
  return TransferLiteralToInfeedOnCpu(local_hardware_id(), literal);
}

Status TfrtCpuDevice::TransferFromOutfeed(MutableBorrowingLiteral literal) {
  return TransferLiteralFromOutfeedOnCpu(local_hardware_id(), literal);
}

static int CpuDeviceCount() {
  // By default we fix the number of devices to one.  However we do let the user
  // override this behavior to help run tests on the host that run models in
  // parallel across multiple devices, e.g. pmap.
  return GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
}

static StatusOr<std::vector<std::unique_ptr<TfrtCpuDevice>>> GetTfrtCpuDevices(
    bool asynchronous, int cpu_device_count) {
  std::vector<std::unique_ptr<TfrtCpuDevice>> devices;
  for (int i = 0; i < cpu_device_count; ++i) {
    auto device = std::make_unique<TfrtCpuDevice>(
        /*id=*/i, asynchronous);
    devices.push_back(std::move(device));
  }
  return std::move(devices);
}

StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous,
                                                       int cpu_device_count) {
  // TODO(zhangqiaorjc): Allow users to set the number of threads.
  // `num_blocking_threads=16` is picked arbitrarily for now.
  // Need at least CpuDeviceCount threads to launch one collective.
  int num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count);
  auto host_context = std::make_unique<tfrt::HostContext>(
      [](const tfrt::DecodedDiagnostic& diag) {
        LOG(ERROR) << "Encountered runtime error: " << diag.message << "\n";
      },
      tfrt::CreateMallocAllocator(),
      tfrt::CreateMultiThreadedWorkQueue(
          /*num_threads=*/num_threads,
          /*num_blocking_threads=*/16));

  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
                      GetTfrtCpuDevices(asynchronous, cpu_device_count));

  return std::unique_ptr<PjRtClient>(std::make_unique<TfrtCpuClient>(
      /*process_index=*/0, std::move(devices), std::move(host_context)));
}

StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous) {
  return GetTfrtCpuClient(asynchronous, CpuDeviceCount());
}

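// The client owns the TFRT HostContext (work queues and allocator) used to run
// dispatched work, a shared Eigen intra-op thread pool used by compiled
// computations, and the CPU devices, all of which are addressable from this
// single process.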
TfrtCpuClient::TfrtCpuClient(
    int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
    std::unique_ptr<tfrt::HostContext> host_ctx)
    : process_index_(process_index),
      owned_devices_(std::move(devices)),
      host_ctx_(std::move(host_ctx)),
      computation_placer_(std::make_unique<ComputationPlacer>()),
      eigen_intraop_pool_(new tensorflow::thread::ThreadPool(
          tensorflow::Env::Default(), "XLAEigen", DefaultThreadPoolSize())),
      eigen_intraop_device_(
          new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(),
                                      eigen_intraop_pool_->NumThreads())),
      last_collective_launch_event_(
          tfrt::MakeAvailableAsyncValueRef<CpuEvent>(host_ctx_.get())),
      transpose_cache_(1024) {
  for (const std::unique_ptr<TfrtCpuDevice>& device : owned_devices_) {
    devices_.push_back(device.get());
    CHECK(id_to_device_.insert({device->id(), device.get()}).second)
        << "Duplicate device id: " << device->id();

    device->SetClient(this);
    if (device->IsAddressable()) {
      int idx = device->local_hardware_id();
      if (idx >= addressable_devices_.size()) {
        addressable_devices_.resize(idx + 1);
      }
      CHECK(addressable_devices_[idx] == nullptr) << idx;
      addressable_devices_[idx] = device.get();
    }
  }
  for (int idx = 0; idx < addressable_devices_.size(); ++idx) {
    CHECK(addressable_devices_[idx] != nullptr) << idx;
  }
  LOG(INFO) << "TfrtCpuClient created.";
}

TfrtCpuClient::~TfrtCpuClient() { LOG(INFO) << "TfrtCpuClient destroyed."; }

StatusOr<PjRtDevice*> TfrtCpuClient::LookupDevice(int device_id) const {
  auto it = id_to_device_.find(device_id);
  if (it != id_to_device_.end()) {
    return it->second;
  }
  return InvalidArgument("No matching device found for device_id %d",
                         device_id);
}

StatusOr<PjRtDevice*> TfrtCpuClient::LookupAddressableDevice(
    int local_hardware_id) const {
  for (auto* device : addressable_devices_) {
    if (local_hardware_id == device->local_hardware_id()) {
      return device;
    }
  }
  return InvalidArgument("No matching device found for local_hardware_id %d",
                         local_hardware_id);
}

StatusOr<DeviceAssignment> TfrtCpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  return computation_placer_->AssignDevices(num_replicas, num_partitions);
}

StatusOr<std::unique_ptr<HloCostAnalysis>> TfrtCpuClient::GetHloCostAnalysis() {
  return std::make_unique<HloCostAnalysis>(cpu::CpuExecutable::ShapeSizeBytes);
}

StatusOr<std::optional<std::string>> TfrtCpuClient::ExecutableFingerprint(
    const PjRtLoadedExecutable& executable) const {
  return std::optional<std::string>();
}

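// JIT-compiles `computation` with the XLA:CPU backend: builds an unoptimized
// HloModule from the computation proto, runs the CPU compiler's HLO passes,
// and then runs the backend to produce an executable.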
static StatusOr<std::unique_ptr<xla::Executable>> JitCompile(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& build_options,
    const ExecutionOptions& execution_options) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  // Unoptimized HloModuleConfig.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> hlo_module_config,
      CreateModuleConfig(program_shape, argument_layouts, &execution_options,
                         execution_options.num_replicas(),
                         /*num_threads=*/std::nullopt,
                         /*aot_options=*/nullptr));

  // Unoptimized HloModule.
  const xla::HloModuleProto& hlo_module_proto = computation.proto();
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> hlo_module,
      xla::HloModule::CreateFromProto(hlo_module_proto, *hlo_module_config));
  VLOG(3) << "Unoptimized HLO module: " << hlo_module->ToString();
  static constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";
  DumpHloModuleIfEnabled(*hlo_module, kBeforeOptimizationsDumpName);

  // Run HLO passes.
  cpu::CpuCompiler compiler;
  xla::Compiler::CompileOptions dummy;
  TF_ASSIGN_OR_RETURN(hlo_module,
                      compiler.RunHloPasses(std::move(hlo_module),
                                            /*stream_exec=*/nullptr, dummy));

  // Run backend.
  return compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
                             dummy);
}

// Returns the InstructionValueSet of the root instruction of the entry
// computation.
static const InstructionValueSet& GetRootValueSet(
    const BufferAssignment& assignment, const HloModule& module) {
  return assignment.dataflow_analysis().GetInstructionValueSet(
      module.entry_computation()->root_instruction());
}

// The buffer table is indexed by buffer allocation indices. The output buffer
// is made up of a subset of those buffer allocations (for a tuple output, it
// includes the tuple index table). This helper finds the buffer allocation
// indices in the buffer assignment that make up the output buffer. It is used
// by CreateResultShapedBuffer to reconstruct the output buffer from the buffer
// table allocated by MemoryForAllocation.
static StatusOr<absl::InlinedVector<BufferAllocation::Index, 4>>
FindResultBufferAllocationIndex(const BufferAssignment& assignment,
                                const HloModule& module) {
  absl::InlinedVector<BufferAllocation::Index, 4> buffer_indices;
  const InstructionValueSet& root_value_set =
      GetRootValueSet(assignment, module);
  const Shape& result_shape = module.result_shape();
  if (!result_shape.IsTuple()) {
    // Find the buffer allocation that corresponds to the output buffer.
    const HloValueSet& sources = root_value_set.element({});
    // The points-to set is unambiguous, so the set should be a singleton.
    CHECK_EQ(1, sources.values().size());
    const HloValue* value_source = sources.values()[0];
    HloInstruction* src = value_source->instruction();
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                        assignment.GetUniqueSlice(src, value_source->index()));
    const BufferAllocation::Index buffer_index = slice.index();
    buffer_indices.push_back(buffer_index);
    return {std::move(buffer_indices)};
  }
  buffer_indices.reserve(result_shape.tuple_shapes_size());
  for (int i = 0; i < result_shape.tuple_shapes_size(); ++i) {
    // Find the buffer allocations that correspond to the output tuple,
    // including the tuple index table.
    const HloValueSet& sources = root_value_set.element({i});
    // The points-to set is unambiguous, so the set should be a singleton.
    CHECK_EQ(1, sources.values().size());
    const HloValue* value_source = sources.values()[0];
    HloInstruction* src = value_source->instruction();
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                        assignment.GetUniqueSlice(src, value_source->index()));
    const BufferAllocation::Index buffer_index = slice.index();
    buffer_indices.push_back(buffer_index);
  }
  return {std::move(buffer_indices)};
}

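// Compiles an XlaComputation into a TfrtCpuExecutable: resolves the device
// assignment and argument layouts, JIT-compiles with the CPU backend, and
// records which buffer allocation indices make up the result so outputs can be
// reassembled after execution.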
StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
    const XlaComputation& computation, CompileOptions options) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuClient::Compile");
  ExecutableBuildOptions& build_options = options.executable_build_options;

  int num_replicas;
  int num_partitions;
  std::shared_ptr<DeviceAssignment> device_assignment;
  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
      options.compile_portable_executable, &options.executable_build_options,
      [this](int num_replicas, int num_partitions) {
        return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
      },
      &num_replicas, &num_partitions, &device_assignment));

  std::vector<const Shape*> argument_layout_pointers;
  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
      computation, &LayoutUtil::GetWithDefaultLayout, options.argument_layouts,
      &options.executable_build_options, &argument_layout_pointers));

  std::vector<PjRtLoadedExecutable::LogicalDeviceIds>
      addressable_device_logical_ids;
  std::vector<PjRtDevice*> addressable_devices;
  if (device_assignment != nullptr) {
    addressable_device_logical_ids.reserve(num_replicas * num_partitions);
    addressable_devices.reserve(num_replicas * num_partitions);
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int partition = 0; partition < num_partitions; ++partition) {
        int device_id = (*device_assignment)(replica, partition);
        TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
        if (device->process_index() != process_index()) {
          VLOG(3) << "Non-local device: " << device_id;
          continue;
        }
        PjRtLoadedExecutable::LogicalDeviceIds logical_device_ids;
        logical_device_ids.replica = replica;
        logical_device_ids.partition = partition;
        addressable_device_logical_ids.push_back(
            std::move(logical_device_ids));
        addressable_devices.push_back(device);
      }
    }
    if (addressable_devices.empty()) {
      return InvalidArgument(
          "Device assignment (%s) does not have any local devices.",
          device_assignment->ToString());
    }

    if (build_options.device_ordinal() < 0) {
      build_options.set_device_ordinal(
          addressable_devices.front()->local_hardware_id());
    }
  }

  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  ExecutionOptions execution_options =
      CreateExecutionOptions(build_options, &program_shape);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> cpu_executable,
                      JitCompile(computation, argument_layout_pointers,
                                 build_options, execution_options));
  auto cpu_executable_ptr =
      tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable.get());

  // `buffer_table[result_slice.index()]` points to result buffer:
  // If output is a tuple, it points to the buffer index table.
  // If output is a non-tuple, it points to the buffer itself.
  TF_ASSIGN_OR_RETURN(
      const BufferAllocation::Slice result_slice,
      cpu_executable_ptr->buffer_assignment().GetUniqueTopLevelOutputSlice());

  // `result_buffer_indices` has the buffer allocation indices that make up the
  // output buffer (could be tuple).
  TF_ASSIGN_OR_RETURN(
      auto result_buffer_indices,
      FindResultBufferAllocationIndex(cpu_executable_ptr->buffer_assignment(),
                                      cpu_executable->module()));

  auto executable = std::make_unique<TfrtCpuExecutable>(
      num_replicas, num_partitions, std::move(device_assignment),
      options.parameter_is_tupled_arguments, std::move(cpu_executable),
      result_slice.index(), std::move(result_buffer_indices),
      std::move(addressable_device_logical_ids), std::move(addressable_devices),
      this);
  TF_RETURN_IF_ERROR(
      executable->SetUpDonation(options.parameter_is_tupled_arguments));

  return std::unique_ptr<PjRtLoadedExecutable>(std::move(executable));
}

StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
    mlir::ModuleOp module, CompileOptions options) {
  XlaComputation xla_computation;
  TF_RETURN_IF_ERROR(MlirToXlaComputation(
      module, xla_computation,
      /*use_tuple_args=*/options.parameter_is_tupled_arguments,
      /*return_tuple=*/false));
  return Compile(xla_computation, options);
}

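// Allocates uninitialized CPU memory for every leaf buffer of
// `on_device_shape` and wraps it in a TfrtCpuBuffer whose readiness is
// controlled by `definition_events`.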
StatusOr<std::unique_ptr<TfrtCpuBuffer>> AllocateDestinationBuffer(
    const Shape& on_device_shape,
    absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events,
    TfrtCpuDevice* device, TfrtCpuClient* client) {
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers;
  if (!on_device_shape.IsTuple()) {
    size_t byte_size = ShapeUtil::ByteSizeOf(on_device_shape);
    TF_ASSIGN_OR_RETURN(auto device_buffer,
                        MaybeOwningCpuMemory::AllocateShared(byte_size));
    buffers.push_back(std::move(device_buffer));
    return std::make_unique<TfrtCpuBuffer>(
        on_device_shape,
        std::make_unique<TrackedTfrtCpuDeviceBuffer>(
            /*is_tuple=*/false, std::move(buffers),
            std::move(definition_events)),
        client, device);
  }
  // Tuple case.
  buffers.reserve(on_device_shape.tuple_shapes().size());
  for (const auto& leaf_shape : on_device_shape.tuple_shapes()) {
    size_t byte_size = ShapeUtil::ByteSizeOf(leaf_shape);
    TF_ASSIGN_OR_RETURN(auto device_buffer,
                        MaybeOwningCpuMemory::AllocateShared(byte_size));
    buffers.push_back(std::move(device_buffer));
  }
  return std::make_unique<TfrtCpuBuffer>(
      on_device_shape,
      std::make_unique<TrackedTfrtCpuDeviceBuffer>(
          /*is_tuple=*/true, std::move(buffers), std::move(definition_events)),
      client, device);
}

StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::CreateViewOfDeviceBuffer(
    void* device_ptr, const Shape& shape, PjRtDevice* device,
    std::function<void()> on_delete_callback) {
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers;
  size_t byte_size = ShapeUtil::ByteSizeOf(shape);
  auto non_owning_buffer =
      std::make_shared<MaybeOwningCpuMemory>(device_ptr, byte_size);
  buffers.push_back(std::move(non_owning_buffer));
  auto tracked_device_buffer = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
      /*is_tuple=*/false, std::move(buffers),
      /*definition_event=*/tfrt::MakeAvailableAsyncValueRef<CpuEvent>(),
      std::move(on_delete_callback));
  return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>(
      shape, std::move(tracked_device_buffer), this,
      tensorflow::down_cast<TfrtCpuDevice*>(device)));
}

StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::CreateUninitializedBuffer(
    const Shape& shape, PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme(
      "TfrtCpuClient::CreateUninitializedBuffer");
  VLOG(1) << "TfrtCpuClient::CreateUninitializedBuffer: shape: "
          << shape.DebugString() << " device: " << device->DebugString();
  return AllocateDestinationBuffer(
      shape, /*definition_events=*/{},
      tensorflow::down_cast<TfrtCpuDevice*>(device), this);
}

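// Creates a device buffer from a host buffer. Depending on layout, alignment,
// and `host_buffer_semantics`, the host data is either aliased directly (zero
// copy), transposed into major-to-minor order, copied synchronously, or copied
// on a worker thread with a definition event that becomes ready when the copy
// completes.
//
// Illustrative call (a sketch only; `client`, `data`, `dims`, and `device` are
// hypothetical):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtBuffer> buffer,
//       client->BufferFromHostBuffer(
//           data, F32, dims, /*byte_strides=*/std::nullopt,
//           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
//           /*on_done_with_host_buffer=*/nullptr, device));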
StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::BufferFromHostBuffer(
    const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
    std::optional<absl::Span<int64_t const>> byte_strides,
    HostBufferSemantics host_buffer_semantics,
    std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostBuffer");
  Shape shape = ShapeUtil::MakeShape(type, dims);
  VLOG(2) << "TfrtCpuClient::BufferFromHostBuffer: shape: " << shape.ToString()
          << " device: " << device->DebugString();
  bool has_default_layout =
      !byte_strides || HasMajorToMinorLayout(type, dims, *byte_strides);
  // If the input buffer has a default layout and is sufficiently aligned, we
  // can simply point to the input array's data without any further copies. At
  // the time of writing we require a 16-byte alignment because XLA may generate
  // code which requires it.
  bool can_use_zero_copy =
      has_default_layout &&
      host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
      ((absl::bit_cast<std::uintptr_t>(data) &
        (cpu_function_runtime::MinAlign() - 1)) == 0);
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers;
  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events;
  std::function<void()> on_delete_callback;
  size_t byte_size = ShapeUtil::ByteSizeOf(shape);
  if (can_use_zero_copy) {
    auto device_buffer = std::make_shared<MaybeOwningCpuMemory>(
        const_cast<void*>(data), byte_size);
    buffers.push_back(std::move(device_buffer));
    on_delete_callback = std::move(on_done_with_host_buffer);
  } else {
    TF_ASSIGN_OR_RETURN(auto device_buffer,
                        MaybeOwningCpuMemory::AllocateShared(byte_size));
    auto dst_data_ptr = device_buffer->data();
    buffers.push_back(device_buffer);
    if (!has_default_layout) {
      // If the input array does not have a major-to-minor layout, transpose it
      // into major-to-minor layout. Currently we choose to always do this
      // synchronously.
      // TODO(phawkins): consider performing the transpose asynchronously.
      // TODO(phawkins): parallelize the transpose.
      std::shared_ptr<TransposePlan> transpose;
      {
        absl::InlinedVector<int64_t, 4> permutation(dims.size());
        absl::c_iota(permutation, 0);
        absl::MutexLock lock(&transpose_mu_);
        TF_ASSIGN_OR_RETURN(
            transpose, transpose_cache_.GetOrCreate(
                           primitive_util::ByteWidth(type), dims, permutation,
                           TransposePlan::Striding{*byte_strides}));
      }
      transpose->Execute(data, dst_data_ptr);
      if (on_done_with_host_buffer) {
        on_done_with_host_buffer();
        on_done_with_host_buffer = nullptr;
      }
    } else {
      bool should_sync_copy =
          host_buffer_semantics ==
              HostBufferSemantics::kImmutableOnlyDuringCall ||
          (byte_size < kSmallDataTransferByteSize);
      if (should_sync_copy) {
        std::memcpy(dst_data_ptr, data, byte_size);
        if (on_done_with_host_buffer) {
          on_done_with_host_buffer();
          on_done_with_host_buffer = nullptr;
        }
      } else {
        tfrt::AsyncValueRef<CpuEvent> copy_event =
            tfrt::MakeConstructedAsyncValueRef<CpuEvent>(host_ctx_.get());
        definition_events.push_back(copy_event.CopyRef());
        tfrt::EnqueueWork(
            host_ctx_.get(),
            [device_buffer = std::move(device_buffer), dst_data_ptr, data,
             byte_size, copy_event = std::move(copy_event),
             on_done_with_host_buffer =
                 std::move(on_done_with_host_buffer)]() mutable {
              tensorflow::profiler::TraceMe traceme("H2D Dispatch");
              std::memcpy(dst_data_ptr, data, byte_size);
              if (on_done_with_host_buffer) {
                on_done_with_host_buffer();
                on_done_with_host_buffer = nullptr;
              }
              // Signal copy is complete.
              copy_event.SetStateConcrete();
            });
      }
    }
  }
  auto tracked_device_buffer = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
      /*is_tuple=*/false, std::move(buffers), std::move(definition_events),
      std::move(on_delete_callback));
  return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>(
      shape, std::move(tracked_device_buffer), this,
      tensorflow::down_cast<TfrtCpuDevice*>(device)));
}

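// Copies `literal` into a newly allocated device buffer. Each leaf buffer (one
// for a non-tuple shape, one per tuple element otherwise) is copied on the
// non-blocking work queue, and its placeholder definition event is marked
// ready when the corresponding copy finishes.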
StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::BufferFromHostLiteral(
    const LiteralSlice& literal, PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostLiteral");
  VLOG(1) << "TfrtCpuClient::BufferFromHostLiteral: shape: "
          << literal.shape().DebugString()
          << " device: " << device->DebugString();
  const Shape& shape = literal.shape();

  // Add a placeholder definition event for each leaf buffer when creating the
  // buffer. They are only marked ready after the H2D copies complete.
  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events;
  absl::InlinedVector<tfrt::RCReference<tfrt::AsyncValue>, 4> avs;
  int num_leaf_buffers = shape.IsTuple() ? shape.tuple_shapes_size() : 1;
  for (int i = 0; i < num_leaf_buffers; ++i) {
    tfrt::AsyncValueRef<CpuEvent> definition_event =
        tfrt::MakeConstructedAsyncValueRef<CpuEvent>(GetHostContext());
    definition_events.push_back(definition_event.CopyRef());
    avs.push_back(std::move(definition_event));
  }
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TfrtCpuBuffer> output_buffer,
                      AllocateDestinationBuffer(
                          shape, std::move(definition_events),
                          tensorflow::down_cast<TfrtCpuDevice*>(device), this));

  auto usage_event = tfrt::MakeAvailableAsyncValueRef<CpuEvent>();
  auto* device_buffer = output_buffer->AcquireUsage(std::move(usage_event));
  CHECK(device_buffer);
  if (!shape.IsTuple()) {
    // It is OK to capture `device_buffer` because `output_buffer` can't be
    // deleted until all the usage holds have gone away.
    tfrt::EnqueueWork(GetHostContext(), [literal, av = avs[0].CopyRef(),
                                         device_buffer, shape]() mutable {
      tensorflow::profiler::TraceMe traceme("H2D Dispatch");
      const std::shared_ptr<MaybeOwningCpuMemory>& b =
          device_buffer->Buffers()[0];
      CHECK_EQ(literal.size_bytes(), b->size());
      std::memcpy(b->data(), literal.untyped_data(), b->size());
      // Signal copy is complete.
      av->SetStateConcrete();
    });
  } else {
    // For tuples, transfer the leaf literals individually in parallel.
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      // It is OK to capture `device_buffer` because `output_buffer` can't
      // be deleted until all the usage holds have gone away.
      tfrt::EnqueueWork(GetHostContext(), [i, literal, av = avs[i].CopyRef(),
                                           shape, device_buffer]() mutable {
        tensorflow::profiler::TraceMe traceme("H2D Dispatch");
        auto slice = LiteralSlice(literal, {i});
        const std::shared_ptr<MaybeOwningCpuMemory>& b =
            device_buffer->Buffers()[i];
        CHECK_EQ(slice.size_bytes(), b->size());
        std::memcpy(b->data(), slice.untyped_data(), slice.size_bytes());
        // Signal copy is complete.
        av->SetStateConcrete();
      });
    }
  }
  return std::unique_ptr<PjRtBuffer>(std::move(output_buffer));
}

TfrtCpuBuffer::TfrtCpuBuffer(
    Shape on_device_shape,
    std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer,
    TfrtCpuClient* client, TfrtCpuDevice* device)
    : client_(client),
      on_device_shape_(std::move(on_device_shape)),
      device_(device),
      tracked_device_buffer_(std::move(tracked_device_buffer)) {}

TfrtCpuBuffer::~TfrtCpuBuffer() {
  Delete();
  CHECK_EQ(external_reference_counter_, 0);
}

StatusOr<size_t> TfrtCpuBuffer::GetOnDeviceSizeInBytes() const {
  return ShapeUtil::ByteSizeOf(on_device_shape_);
}

StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
TfrtCpuBuffer::AcquireExternalReference() {
  class ScopedExternalReference : public PjRtBuffer::ExternalReference {
   public:
    explicit ScopedExternalReference(TfrtCpuBuffer* buffer,
                                     std::shared_ptr<MaybeOwningCpuMemory> data)
        : buffer_(buffer), data_(std::move(data)) {
      DCHECK(data_);
      data_ptr_ = data_->data();
    }

    ~ScopedExternalReference() override { buffer_->DropExternalReference(); }

   private:
    TfrtCpuBuffer* buffer_ = nullptr;
    // Keep a reference to the underlying data. Note that it is still the
    // user's responsibility to synchronize reads and writes to the data.
    std::shared_ptr<MaybeOwningCpuMemory> data_;
  };

  absl::MutexLock lock(&mu_);
  if (tracked_device_buffer_ == nullptr) {
    return InvalidArgument("Buffer has been deleted or donated.");
  }

  ++external_reference_counter_;

  return {std::make_unique<ScopedExternalReference>(
      this, tracked_device_buffer_->Buffers()[0])};
}

class TrackedCpuDeviceBufferExternalReference
    : public PjRtBuffer::ExternalReference {
 public:
  explicit TrackedCpuDeviceBufferExternalReference(
      std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer)
      : tracked_device_buffer_(std::move(tracked_device_buffer)) {
    data_ptr_ = tracked_device_buffer_->Buffers()[0]->data();
  }

  ~TrackedCpuDeviceBufferExternalReference() override = default;

 private:
  std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer_;
};

StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
TfrtCpuBuffer::ReleaseDeviceMemoryOwnership(
    bool wait_for_operations_to_complete) {
  if (on_device_shape_.IsTuple()) {
    return InvalidArgument(
        "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer,
      Release(wait_for_operations_to_complete));

  std::unique_ptr<PjRtBuffer::ExternalReference> ref;
  if (tracked_device_buffer) {
    ref = std::make_unique<TrackedCpuDeviceBufferExternalReference>(
        std::move(tracked_device_buffer));
  }
  return ref;
}

void TfrtCpuBuffer::CommitDonation() {
  absl::MutexLock lock(&mu_);
  CHECK(pending_donation_);
  CHECK(!tracked_device_buffer_);
  pending_donation_ = false;
}

void TfrtCpuBuffer::AbortDonation(
    std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer) {
  absl::MutexLock lock(&mu_);
  CHECK(pending_donation_);
  CHECK(!tracked_device_buffer_);
  pending_donation_ = false;
  tracked_device_buffer_ = std::move(device_buffer);
}

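// Drops this buffer's reference to its device memory. The memory itself is
// freed only after all usage events and the definition event have become
// ready.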
void TfrtCpuBuffer::Delete() {
  auto device_buffer = ReleaseBufferLocked();
  if (device_buffer == nullptr) return;

  // Now that all holds have completed and no more can be added, we can get
  // the final set of usage events.
  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> usage_events =
      device_buffer->LockUseAndTransferUsageEvents();

  std::vector<tfrt::AsyncValue*> event_avs;
  event_avs.reserve(usage_events.size() + 1);
  for (auto& event : usage_events) {
    event_avs.push_back(event.GetAsyncValue());
  }

  // We should also wait for the definition event.
  event_avs.push_back(device_buffer->definition_event().GetAsyncValue());

  tfrt::RunWhenReady(event_avs,
                     [device_buffer = std::move(device_buffer)]() mutable {
                       device_buffer.reset();
                     });
}

bool TfrtCpuBuffer::IsDeleted() {
  absl::MutexLock lock(&mu_);
  return tracked_device_buffer_ == nullptr;
}

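// Waits for any pending donation to settle and then transfers ownership of the
// tracked device buffer (if any) out of this TfrtCpuBuffer.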
std::unique_ptr<TrackedTfrtCpuDeviceBuffer>
TfrtCpuBuffer::ReleaseBufferLocked() {
  absl::MutexLock lock(&mu_);
  auto condition = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) {
    return !pending_donation_;
  };
  mu_.Await(absl::Condition(&condition));
  return std::move(tracked_device_buffer_);
}

StatusOr<std::unique_ptr<TrackedTfrtCpuDeviceBuffer>> TfrtCpuBuffer::Release(
    bool wait_for_operations_to_complete) {
  std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer =
      ReleaseBufferLocked();
  if (device_buffer == nullptr) return {nullptr};

  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> events;
  // Now that all holds have completed and no more can be added, we can get
  // the final set of usage events.
  events = device_buffer->LockUseAndTransferUsageEvents();

  if (wait_for_operations_to_complete) {
    // Block the host until all usage events have completed. Usage events
    // dominate definition events, so this also waits for the buffer to be
    // defined. Return the first error encountered.
    Status first_error;
    for (const auto& av : events) {
      client_->GetHostContext()->Await(av.CopyRCRef());
      if (auto* error = av.GetErrorIfPresent()) {
        first_error.Update(InternalError("Error Execute: %s", error->message));
      }
    }
    if (!first_error.ok()) return std::move(first_error);
  }

  return device_buffer;
}

TrackedTfrtCpuDeviceBuffer* TfrtCpuBuffer::AcquireUsage(
    tfrt::AsyncValueRef<CpuEvent> usage_event) {
  absl::MutexLock lock(&mu_);
  if (!tracked_device_buffer_) {
    return nullptr;
  }

  tracked_device_buffer_->AddUsageEvents(absl::MakeSpan(&usage_event, 1));
  return tracked_device_buffer_.get();
}

StatusOr<TfrtCpuBuffer::DonationTransaction> TfrtCpuBuffer::AcquireDonation() {
  absl::MutexLock lock(&mu_);

  if (tracked_device_buffer_ == nullptr) {
    return InvalidArgument("Donation requested for invalid buffer");
  }

  if (external_reference_counter_ > 0) {
    return InvalidArgument(
        "Donation requested for buffer with external reference");
  }

  CHECK(!pending_donation_);
  pending_donation_ = true;

  // Swap out `tracked_device_buffer_` so that no one can acquire a usage event
  // after this point.
  return DonationTransaction(this, std::move(tracked_device_buffer_));
}

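// Wraps the leaf CPU memory blocks of a buffer in a ShapedBuffer so that
// existing XLA helpers (e.g. ReadDynamicShapesOnCpu) can operate on them. The
// number of memory blocks must match the number of leaf buffers in
// `on_device_shape`.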
static ShapedBuffer AsShapedBuffer(
    int device_ordinal, const Shape& on_device_shape,
    absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> buffers) {
  ShapedBuffer shaped_buffer(on_device_shape, device_ordinal);
  ShapeTree<se::DeviceMemoryBase>::iterator iterator =
      shaped_buffer.buffers().begin();
  for (const auto& buf : buffers) {
    CHECK(iterator != shaped_buffer.buffers().end());
    iterator->second = se::DeviceMemoryBase(buf->data(), buf->size());
    ++iterator;
  }
  CHECK(iterator == shaped_buffer.buffers().end());
  return shaped_buffer;
}

StatusOr<Shape> TfrtCpuBuffer::logical_on_device_shape() {
  if (on_device_shape_.is_static()) {
    return on_device_shape_;
  }

  auto usage_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
  auto* device_buffer = AcquireUsage(usage_event);
  if (device_buffer == nullptr) {
    return InvalidArgument(
        "logical_on_device_shape() called on deleted or donated buffer");
  }
  MarkEventReadyOnExit ready_on_exit(std::move(usage_event));

  // Wait for the definition event.
  const auto& av = device_buffer->definition_event();
  client_->GetHostContext()->Await(av.CopyRCRef());
  if (auto* error = av.GetErrorIfPresent()) {
    return InternalError("Error Execute: %s", error->message);
  }

  ShapedBuffer shaped_buffer = AsShapedBuffer(
      device_->local_hardware_id(), on_device_shape_, device_buffer->Buffers());
  Shape ret_shape = on_device_shape_;
  TF_RETURN_IF_ERROR(ReadDynamicShapesOnCpu(
      &shaped_buffer, &ret_shape, cpu::CpuExecutable::ShapeSizeBytes));
  return ret_shape;
}

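// Converts typed CpuEvent references into type-erased AsyncValue references
// suitable for passing to tfrt::RunWhenReady.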
static std::vector<tfrt::RCReference<tfrt::AsyncValue>> GetAsyncValues(
    absl::Span<const tfrt::AsyncValueRef<CpuEvent>> events) {
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> avs;
  avs.reserve(events.size());
  for (const auto& ev : events) {
    avs.push_back(ev.CopyRCRef());
  }
  return avs;
}

static std::vector<tfrt::RCReference<tfrt::AsyncValue>> CopyAsyncValues(
    absl::Span<const tfrt::RCReference<tfrt::AsyncValue>> events) {
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> avs;
  avs.reserve(events.size());
  for (const auto& ev : events) {
    avs.push_back(ev.CopyRef());
  }
  return avs;
}

// Enqueue to TFRT non-blocking work queue when all `values` are ready.
static void EnqueueWorkWhenReady(
    tfrt::HostContext* host_ctx,
    tfrt::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> values,
    llvm::unique_function<void()> callee) {
  tfrt::RunWhenReady(values, [host_ctx, callee = std::move(callee)]() mutable {
    tfrt::EnqueueWork(host_ctx, std::move(callee));
  });
}

PjRtFuture<Status> TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuBuffer::ToLiteral");
  if (IsEmptyTuple()) {
    return PjRtFuture<Status>(
        InvalidArgument("ToLiteral called on empty tuple"));
  }
  auto usage_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
  auto* device_buffer = AcquireUsage(usage_event);
  if (device_buffer == nullptr) {
    return PjRtFuture<Status>(InvalidArgument(
        "ToLiteral() called on deleted or donated buffer"));
  }
  MarkEventReadyOnExit ready_on_exit(std::move(usage_event));

  auto host_ctx = client_->GetHostContext();

  std::vector<tfrt::RCReference<tfrt::AsyncValue>> device_buffer_wait_avs = {
      device_buffer->definition_event().CopyRCRef()};
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> device_buffer_wait_avs_copy =
      CopyAsyncValues(device_buffer_wait_avs);

  bool should_sync_copy = device_buffer_wait_avs.empty() &&
                          literal->size_bytes() < kSmallDataTransferByteSize;
  if (should_sync_copy) {
    if (!on_device_shape().IsTuple()) {
      const std::shared_ptr<MaybeOwningCpuMemory>& b =
          device_buffer->Buffers()[0];
      std::memcpy(literal->untyped_data(), b->data(), b->size());
    } else {
      // Tuple case.
      int num_leaves = literal->shape().tuple_shapes().size();
      for (int i = 0; i < num_leaves; ++i) {
        const std::shared_ptr<MaybeOwningCpuMemory>& b =
            device_buffer->Buffers()[i];
        std::memcpy(literal->untyped_data({i}), b->data(), b->size());
      }
    }
    // Unblock ToLiteral caller.
    return PjRtFuture<Status>(OkStatus());
  } else {
    auto ready_event = tfrt::MakeUnconstructedAsyncValueRef<Status>();
    // Wait for buffer definition events to finish before D2H dispatch. D2H
    // dispatches should run in parallel, e.g. when a single Execute event
    // triggers D2H copies for multiple outputs, those copies should run on
    // different threads in parallel.
    EnqueueWorkWhenReady(
        host_ctx, device_buffer_wait_avs,
        [this, device_buffer_wait_avs = std::move(device_buffer_wait_avs_copy),
         literal, ready_event = ready_event.CopyRef(), device_buffer,
         ready_on_exit = std::move(ready_on_exit)]() mutable {
          tensorflow::profiler::TraceMe traceme("D2H Dispatch");
          // Errors in the src buffer are surfaced to the user.
          for (const auto& av : device_buffer_wait_avs) {
            if (auto* error = av->GetErrorIfPresent()) {
              ready_event.emplace(InternalError(
                  "Error converting to literal: %s", error->message));
              return;
            }
          }

          if (!on_device_shape().IsTuple()) {
            const std::shared_ptr<MaybeOwningCpuMemory>& b =
                device_buffer->Buffers()[0];
            std::memcpy(literal->untyped_data(), b->data(), b->size());
          } else {
            // Tuple case.
            int num_leaves = literal->shape().tuple_shapes().size();
            for (int i = 0; i < num_leaves; ++i) {
              const std::shared_ptr<MaybeOwningCpuMemory>& b =
                  device_buffer->Buffers()[i];
              std::memcpy(literal->untyped_data({i}), b->data(), b->size());
            }
          }

          // Unblock ToLiteral event.
          ready_event.emplace(OkStatus());
        });
    return PjRtFuture<Status>(
        std::move(ready_event),
        /*on_block_start=*/
        []() {
          tensorflow::profiler::TraceMeProducer traceme(
              "TfrtCpuBuffer::ToLiteral");
          VLOG(1) << "TfrtCpuBuffer::ToLiteral";
          return PjRtFutureHelpers::ProfilingKeys(
              {/*traceme_context_id=*/traceme.GetContextId()});
        },
        /*on_block_end=*/
        [](PjRtFutureHelpers::ProfilingKeys keys) {
          tensorflow::profiler::TraceMeConsumer traceme(
              "TfrtCpuBuffer::ToLiteral", keys.traceme_context_id);
        });
  }
}

// TODO(zhangqiaorjc): Consider disallowing multiple CPU devices and assign
// multiple pmap replicas to the same CPU device for multi-CPU pmap testing.
StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuBuffer::CopyToDevice(
    PjRtDevice* dst_device) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuBuffer::CopyToDevice");
  // TODO(zhangqiaorjc): Remove this restriction after removing the test that
  // explicitly asserts this.
  if (dst_device == device_) {
    return InvalidArgument(
        "CopyToDevice cannot accept the same source and destination devices");
  }

  // Copying across PjRtClients involves a copy through the host.
  if (dst_device->client() != client_) {
    TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteralSync());
    // Avoid use-after-free on `literal` due to unsequenced move and use.
    Literal* literal_pointer = literal.get();
    absl::InlinedVector<int64_t, 4> byte_strides(
        literal->shape().dimensions_size());
    TF_RETURN_IF_ERROR(
        ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides)));
    return dst_device->client()->BufferFromHostBuffer(
        literal_pointer->untyped_data(),
        literal_pointer->shape().element_type(),
        literal_pointer->shape().dimensions(), byte_strides,
        TfrtCpuClient::HostBufferSemantics::kZeroCopy,
        [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
  }

  // Copy each leaf buffer to a destination buffer.
  auto usage_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
  auto* src_device_buffer = AcquireUsage(usage_event);
  if (src_device_buffer == nullptr) {
    return InvalidArgument("CopyToDevice called on deleted or donated buffer");
  }
  MarkEventReadyOnExit ready_on_exit(std::move(usage_event));

  int num_leaf_buffers = src_device_buffer->Buffers().size();
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> src_buffers;
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> dst_buffers;
  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> dst_definition_events;
  src_buffers.reserve(num_leaf_buffers);
  dst_buffers.reserve(num_leaf_buffers);
  dst_definition_events.reserve(num_leaf_buffers);

  for (int i = 0; i < num_leaf_buffers; ++i) {
    auto src_buffer = src_device_buffer->Buffers()[i];
    TF_ASSIGN_OR_RETURN(auto dst_buffer, MaybeOwningCpuMemory::AllocateShared(
                                             src_buffer->size()));
    src_buffers.push_back(std::move(src_buffer));
    dst_buffers.push_back(std::move(dst_buffer));
    dst_definition_events.push_back(
        tfrt::MakeConstructedAsyncValueRef<CpuEvent>());
  }

  // Wait for src buffer definition events to finish before d2d dispatch.
  // Errors are propagated asynchronously in dst buffer's definition events.
  const auto& src_definition_event = src_device_buffer->definition_event();

  auto copy_task = [num_leaf_buffers, src_buffers = std::move(src_buffers),
                    dst_buffers_copies = dst_buffers, dst_definition_events,
                    src_definition_event,
                    ready_on_exit = std::move(ready_on_exit)]() mutable {
    tensorflow::profiler::TraceMe traceme("D2D Dispatch");
    if (auto* error = src_definition_event.GetErrorIfPresent()) {
      for (int i = 0; i < num_leaf_buffers; ++i) {
        // Any error discovered in the src buffer is propagated to the dst
        // buffer definition events, which will surface to users in
        // dst_buffer->ToLiteral().
        dst_definition_events[i].SetError(*error);
      }
      return;
    }

    for (int i = 0; i < num_leaf_buffers; ++i) {
      std::memcpy(dst_buffers_copies[i]->data(), src_buffers[i]->data(),
                  src_buffers[i]->size());
      dst_definition_events[i].SetStateConcrete();
    }
  };

  src_definition_event.AndThen([host_ctx = client()->GetHostContext(),
                                copy_task = std::move(copy_task)]() mutable {
    tfrt::EnqueueWork(host_ctx, std::move(copy_task));
  });

  return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>(
      on_device_shape_,
      std::make_unique<TrackedTfrtCpuDeviceBuffer>(
          on_device_shape_.IsTuple(), std::move(dst_buffers),
          std::move(dst_definition_events)),
      client(), tensorflow::down_cast<TfrtCpuDevice*>(dst_device)));
}

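// Returns a future that becomes ready when the buffer's definition event is
// ready. An error in the definition event is surfaced as a FailedPrecondition
// status.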
PjRtFuture<Status> TfrtCpuBuffer::GetReadyFuture() {
  tfrt::AsyncValueRef<CpuEvent> definition_event;
  {
    absl::MutexLock lock(&mu_);
    if (!tracked_device_buffer_) {
      return PjRtFuture<Status>(InvalidArgument(
          "GetReadyFuture() called on deleted or donated buffer"));
    }
    definition_event = tracked_device_buffer_->definition_event();
  }
  DCHECK(definition_event);

  if (definition_event.IsAvailable()) {
    if (definition_event.IsError()) {
      return PjRtFuture<Status>(FailedPrecondition(
          "Buffer Definition Event: %s", definition_event.GetError().message));
    }
    return PjRtFuture<Status>(OkStatus());
  } else {
    tfrt::AsyncValueRef<Status> status_event =
        tfrt::MakeUnconstructedAsyncValueRef<Status>();

    definition_event.AndThen(
        [definition_event = definition_event.AsPtr(), status_event]() {
          if (definition_event.IsError()) {
            status_event.emplace(
                FailedPrecondition("Buffer Definition Event: %s",
                                   definition_event.GetError().message));
          } else {
            status_event.emplace(OkStatus());
          }
        });

    return PjRtFuture<Status>(
        std::move(status_event),
        /*on_block_start=*/
        []() {
          tensorflow::profiler::TraceMeProducer traceme("TfrtCpuBuffer::Await");
          VLOG(1) << "TfrtCpuBuffer::Await";
          return PjRtFutureHelpers::ProfilingKeys(
              {/*traceme_context_id=*/traceme.GetContextId()});
        },
        /*on_block_end=*/
        [](PjRtFutureHelpers::ProfilingKeys keys) {
          tensorflow::profiler::TraceMeConsumer traceme(
              "TfrtCpuBuffer::Await", keys.traceme_context_id);
        });
  }
}

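// A TfrtCpuExecutable wraps a compiled cpu::CpuExecutable together with its
// device assignment, the buffer allocation indices that make up its result,
// and the byte sizes it expects for each argument buffer.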
TfrtCpuExecutable(int num_replicas,int num_partitions,std::shared_ptr<DeviceAssignment> device_assignment,bool parameter_is_tupled_arguments,std::unique_ptr<Executable> cpu_executable,BufferAllocation::Index result_buffer_index,absl::InlinedVector<BufferAllocation::Index,4> result_buffer_indices,std::vector<LogicalDeviceIds> addressable_device_logical_ids,std::vector<PjRtDevice * > addressable_devices,TfrtCpuClient * client)1157 TfrtCpuExecutable::TfrtCpuExecutable(
1158     int num_replicas, int num_partitions,
1159     std::shared_ptr<DeviceAssignment> device_assignment,
1160     bool parameter_is_tupled_arguments,
1161     std::unique_ptr<Executable> cpu_executable,
1162     BufferAllocation::Index result_buffer_index,
1163     absl::InlinedVector<BufferAllocation::Index, 4> result_buffer_indices,
1164     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1165     std::vector<PjRtDevice*> addressable_devices, TfrtCpuClient* client)
1166     : client_(client),
1167       num_replicas_(num_replicas),
1168       num_partitions_(num_partitions),
1169       device_assignment_(std::move(device_assignment)),
1170       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1171       cpu_executable_(std::move(cpu_executable)),
1172       result_buffer_index_(result_buffer_index),
1173       result_buffer_indices_(std::move(result_buffer_indices)),
1174       addressable_device_logical_ids_(
1175           std::move(addressable_device_logical_ids)),
1176       addressable_devices_(std::move(addressable_devices)) {
1177   auto hlo_cost_analysis =
1178       std::make_unique<HloCostAnalysis>(cpu::CpuExecutable::ShapeSizeBytes);
1179   // Cache to avoid std::map lookup in flop_count() on critical path.
1180   // The magic constant 1000 is determined by correlating computation with flop
1181   // estimate. It is a crude heuristic to find computation less than the thread
1182   // context switch time (~5us).
1183   cheap_computation_ = hlo_cost_analysis->flop_count() < 1000;
1184 
1185   const auto& computation_layout =
1186       cpu_executable_->module().entry_computation_layout();
1187   if (computation_layout.parameter_count() == 0) {
1188     return;
1189   }
1190   // Assume compiled program expects either many non-tupled arguments or a
1191   // single tupled argument. Nested tuples are not yet supported.
1192   if (computation_layout.parameter_count() > 1 ||
1193       !computation_layout.parameter_shape(0).IsTuple()) {
1194     input_buffer_sizes_in_bytes_.reserve(computation_layout.parameter_count());
1195     for (int i = 0; i < computation_layout.parameter_count(); ++i) {
1196       input_buffer_sizes_in_bytes_.push_back(
1197           ShapeUtil::ByteSizeOf(computation_layout.parameter_shape(i)));
1198     }
1199   } else {
1200     input_buffer_sizes_in_bytes_.reserve(
1201         computation_layout.parameter_shape(0).tuple_shapes_size());
1202     for (int i = 0;
1203          i < computation_layout.parameter_shape(0).tuple_shapes_size(); ++i) {
1204       input_buffer_sizes_in_bytes_.push_back(ShapeUtil::ByteSizeOf(
1205           computation_layout.parameter_shape(0).tuple_shapes(i)));
1206     }
1207   }
1208 }
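// The loop above flattens parameter sizes: one entry per top-level parameter
// in the untupled case, or one entry per element of the single tuple parameter
// otherwise. As a hypothetical example, parameters f32[2] and f32[4] give
// sizes {8, 16}, and a single tupled parameter (f32[2], f32[4]) gives the same
// flattened sizes {8, 16}.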
1209 
1210 void TfrtCpuExecutable::Delete() {}
1211 
1212 bool TfrtCpuExecutable::IsDeleted() { return false; }
1213 
1214 StatusOr<std::optional<std::string>> TfrtCpuExecutable::Fingerprint() const {
1215   return std::optional<std::string>();
1216 }
1217 
1218 Status TfrtCpuExecutable::SetUpDonation(bool tuple_inputs) {
1219   TF_ASSIGN_OR_RETURN(parameters_that_must_be_donated_,
1220                       ComputeParametersThatMustBeDonated(
1221                           *cpu_executable_->shared_module(), tuple_inputs));
1222   return OkStatus();
1223 }
1224 
1225 // The following few helpers are adapted from XLA:CPU to create a buffer table
1226 // and assemble the buffer pointers in order to call into CpuExecutable.
1227 static StatusOr<std::shared_ptr<MaybeOwningCpuMemory>> MemoryForAllocation(
1228     const BufferAllocation& allocation,
1229     absl::Span<TrackedTfrtCpuDeviceBuffer* const> arguments) {
1230   if (allocation.is_entry_computation_parameter()) {
1231     TrackedTfrtCpuDeviceBuffer* arg = arguments[allocation.parameter_number()];
1232     std::shared_ptr<MaybeOwningCpuMemory> out =
1233         arg->Buffer(allocation.param_shape_index());
1234     CHECK_EQ(allocation.size(), out->size())
1235         << "Size mismatch on param " << allocation.parameter_number()
1236         << " at shape index " << allocation.param_shape_index().ToString();
1237     return out;
1238   } else if (allocation.is_constant()) {
1239     return std::make_shared<MaybeOwningCpuMemory>();
1240   } else if (allocation.is_thread_local()) {
1241     return std::make_shared<MaybeOwningCpuMemory>();
1242   }
1243 
1244   // Output and temporary buffer.
1245   int64_t buffer_size = allocation.size();
1246   TF_ASSIGN_OR_RETURN(auto out,
1247                       MaybeOwningCpuMemory::AllocateShared(buffer_size));
1248 
1249   // Since the output buffer and all the temporary buffers were written into
1250   // by the JITed code, msan has no way of knowing their memory was
1251   // initialized. Mark them initialized so that msan doesn't flag loads from
1252   // these buffers.
1253   ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(out->data(), buffer_size);
1254   return out;
1255 }
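// To summarize the cases above: entry-computation parameters alias the
// caller-supplied argument buffers, constant and thread-local allocations are
// given empty placeholders here, and output/temporary allocations are freshly
// allocated and msan-annotated.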
1256 
1257 static StatusOr<std::vector<std::shared_ptr<MaybeOwningCpuMemory>>>
1258 CreateBufferTable(const BufferAssignment& assignment,
1259                   absl::Span<TrackedTfrtCpuDeviceBuffer* const> arguments) {
1260   std::vector<std::shared_ptr<MaybeOwningCpuMemory>> buffers(
1261       assignment.Allocations().size());
1262   for (BufferAllocation::Index i = 0; i < assignment.Allocations().size();
1263        ++i) {
1264     const BufferAllocation& allocation = assignment.GetAllocation(i);
1265     TF_ASSIGN_OR_RETURN(buffers[i], MemoryForAllocation(allocation, arguments));
1266   }
1267   return std::move(buffers);
1268 }
1269 
1270 static absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4>
1271 CreateResultShapedBuffer(
1272     absl::Span<const BufferAllocation::Index> buffer_indices,
1273     absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> buffer_table,
1274     absl::Span<TrackedTfrtCpuDeviceBuffer* const> arguments) {
1275   absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> output_buffers;
1276   output_buffers.reserve(buffer_indices.size());
1277   for (int i = 0; i < buffer_indices.size(); ++i) {
1278     output_buffers.push_back(buffer_table[buffer_indices[i]]);
1279   }
1280   return output_buffers;
1281 }
1282 
1283 Status TfrtCpuExecutable::CheckBufferCompatibilities(
1284     absl::Span<TrackedTfrtCpuDeviceBuffer* const> input_buffers) const {
1285   if (input_buffers.size() != input_buffer_sizes_in_bytes_.size()) {
1286     return InvalidArgument(
1287         "Execution supplied %lld buffers but compiled program expected %lld "
1288         "buffers",
1289         input_buffers.size(), input_buffer_sizes_in_bytes_.size());
1290   }
1291   for (int i = 0; i < input_buffers.size(); ++i) {
1292     const auto& buffer = input_buffers[i];
1293     if (input_buffer_sizes_in_bytes_[i] != buffer->Buffers()[0]->size()) {
1294       return InvalidArgument(
1295           "Executable expected parameter %d of size %lld but got buffer with "
1296           "incompatible size %lld",
1297           i, input_buffer_sizes_in_bytes_[i], buffer->Buffers()[0]->size());
1298     }
1299   }
1300   return OkStatus();
1301 }
1302 
1303 StatusOr<PjRtLoadedExecutable::Result> TfrtCpuExecutable::ExecuteHelper(
1304     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1305     const RunId& run_id, const ExecuteOptions& options,
1306     tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event,
1307     bool fill_future, TfrtCpuDevice* device) {
1308   tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecuteHelper");
1309   auto* host_context = client_->GetHostContext();
1310 
1311   std::shared_ptr<DeviceAssignment> device_assignment;
1312   if (device == nullptr) {
1313     CHECK(device_assignment_ != nullptr);
1314     const int device_id = (*device_assignment_)(replica, partition);
1315     TF_ASSIGN_OR_RETURN(PjRtDevice * pjrt_device,
1316                         client_->LookupDevice(device_id));
1317     device = tensorflow::down_cast<TfrtCpuDevice*>(pjrt_device);
1318     device_assignment = device_assignment_;
1319   } else {
1320     CHECK(device_assignment_ == nullptr);
1321     CHECK_EQ(replica, 0);
1322     CHECK_EQ(partition, 0);
1323     CHECK(addressable_devices_.empty());
1324     device_assignment = std::make_shared<DeviceAssignment>(1, 1);
1325     (*device_assignment)(0, 0) = device->id();
1326   }
1327   CHECK_EQ(device->process_index(), client_->process_index());
1328 
1329   // Handle inputs.
1330   if (options.arguments_are_tupled) {
1331     if (!parameter_is_tupled_arguments_) {
1332       return InvalidArgument(
1333           "Arguments may only be supplied as a tuple when the executable was "
1334           "compiled with a single tupled parameter");
1335     }
1336     if (argument_handles.size() != 1) {
1337       return InvalidArgument(
1338           "Option arguments_are_tupled was true but %d buffers were passed to "
1339           "execution",
1340           argument_handles.size());
1341     }
1342   }
1343 
1344   // `execute_event` indicates whether cpu computation is complete and whether
1345   // there was an error.
1346   auto execute_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
1347   MarkEventReadyOnExit ready_on_exit(execute_event);
1348 
1349   absl::InlinedVector<TfrtCpuBuffer::DonationTransaction, 4>
1350       donation_transactions;
1351   absl::InlinedVector<TrackedTfrtCpuDeviceBuffer*, 4> tracked_buffers;
1352   tracked_buffers.reserve(argument_handles.size());
1353   // To avoid clobbering inputs, we must ensure that
1354   //   `extra_deps` = inputs' definition events + donated inputs' usage events.
1355   // This also ensures that the returned `execute_event` dominates all inputs'
1356   // events, and thus the output buffers only need to contain `execute_event` as the
1357   // single definition event.
1358   std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps;
1359   input_deps.reserve(argument_handles.size());
1360 
1361   auto donate_it = parameters_that_must_be_donated_.begin();
1362 
1363   for (int i = 0; i < argument_handles.size(); ++i) {
1364     PjRtBuffer* handle = argument_handles[i];
1365     auto* tfrt_buffer = tensorflow::down_cast<TfrtCpuBuffer*>(handle);
1366     if (tfrt_buffer->device() != device) {
1367       return InvalidArgument(
1368           "Buffer passed to Execute() as argument %d to replica %d is on "
1369           "device %s, but replica is assigned to device %s.",
1370           i, replica, tfrt_buffer->device()->DebugString(),
1371           device->DebugString());
1372     }
1373 
1374     bool must_donate =
1375         donate_it != parameters_that_must_be_donated_.end() && *donate_it == i;
1376     TrackedTfrtCpuDeviceBuffer* tracked_buffer = nullptr;
1377     if (must_donate) {
1378       ++donate_it;
1379       TF_ASSIGN_OR_RETURN(auto donation_transaction,
1380                           tfrt_buffer->AcquireDonation());
1381 
1382       // After acquiring the buffer for donation, we retrieve the dependent
1383       // usage events. Note that we don't need any locking here as
1384       // AcquireDonation() is supposed to synchronize with other usages.
1385       for (const auto& ev :
1386            donation_transaction.device_buffer()->UsageEvents()) {
1387         if (!ev.IsAvailable()) {
1388           input_deps.push_back(ev.CopyRCRef());
1389         }
1390       }
1391       tracked_buffer = donation_transaction.device_buffer();
1392       tracked_buffers.push_back(tracked_buffer);
1393       donation_transactions.push_back(std::move(donation_transaction));
1394 
1395     } else {
1396       tracked_buffer = tfrt_buffer->AcquireUsage(execute_event);
1397       if (!tracked_buffer)
1398         return InvalidArgument(
1399             "Invalid buffer passed: buffer has been deleted or donated.");
1400       tracked_buffers.push_back(tracked_buffer);
1401     }
1402 
1403     // Definition events are never modified after buffer construction.
1404     const auto& definition_event = tracked_buffer->definition_event();
1405     if (!definition_event.IsAvailable()) {
1406       input_deps.push_back(definition_event.CopyRCRef());
1407     }
1408   }
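  // At this point `input_deps` holds every not-yet-available definition event
  // of the inputs, plus the not-yet-available usage events of donated inputs,
  // matching the `extra_deps` requirement described above.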
1409 
1410   TF_RETURN_IF_ERROR(CheckBufferCompatibilities(tracked_buffers));
1411 
1412   // Tuplize the inputs if the compiler expects a single tuple argument but the
1413   // runtime gets many inputs that are not yet tupled.
1414   std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tuplized_arg;
1415   if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1416     absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> leaf_buffers;
1417     leaf_buffers.reserve(tracked_buffers.size());
1418     for (const auto& tracked_buffer : tracked_buffers) {
1419       auto span = tracked_buffer->Buffers();
1420       leaf_buffers.insert(leaf_buffers.end(), span.begin(), span.end());
1421     }
1422 
1423     // Tuplize into a single input.
1424     tracked_buffers.clear();
1425     tuplized_arg = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
1426         /*is_tuple=*/true, std::move(leaf_buffers),
1427         /*definition_event=*/tfrt::MakeAvailableAsyncValueRef<CpuEvent>());
1428     tracked_buffers.push_back(tuplized_arg.get());
1429   }
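  // `tuplized_arg` must stay alive until the computation finishes: on the
  // synchronous path it lives on this stack frame across the call, and on the
  // asynchronous path it is moved into the dispatch closure below.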
1430 
1431   auto* cpu_executable =
1432       tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable_.get());
1433   TF_ASSIGN_OR_RETURN(
1434       std::vector<std::shared_ptr<MaybeOwningCpuMemory>> buffer_table,
1435       CreateBufferTable(cpu_executable->buffer_assignment(), tracked_buffers));
1436   auto result_buffers = CreateResultShapedBuffer(result_buffer_indices_,
1437                                                  buffer_table, tracked_buffers);
1438 
1439   // The choice of where we wait is arbitrary; the reason for the wait is
1440   // pacing to avoid problems such as memory fragmentation and running ahead
1441   // too far, not for correctness. Placing it before the executable launch
1442   // allows the inputs for the next executable to be fetched even if the
1443   // launch is delayed.
1444   auto compute_reservation = std::make_unique<Semaphore::ScopedReservation>(
1445       device->max_inflight_computations_semaphore().ScopedAcquire(1));
1446 
1447   // Call the computation function following the calling convention.
1448   std::vector<void*> buffer_pointers;
1449   buffer_pointers.reserve(buffer_table.size());
1450   for (const auto& buffer : buffer_table) {
1451     buffer_pointers.push_back(buffer->data());
1452   }
1453   void* result_buffer = buffer_pointers[result_buffer_index_];
1454 
1455   ExecutableRunOptions run_options;
1456   run_options.set_run_id(run_id);
1457   run_options.set_device_ordinal(device->local_hardware_id());
1458   // Need to keep device_assignment alive until execution completes.
1459   run_options.set_device_assignment(device_assignment.get());
1460   run_options.set_intra_op_thread_pool(client_->eigen_intraop_device());
1461 
1462   // Schedule only one collective at a time.
1463   bool is_a_collective_launch = !!last_collective_launch_event;
1464   if (is_a_collective_launch) {
1465     input_deps.push_back(std::move(last_collective_launch_event));
1466   }
1467 
1468   bool execute_inline = cheap_computation_;
1469 
1470   // Overwrite `execute_inline` if it is specified in the ExecuteOptions.
1471   if (options.execution_mode == ExecuteOptions::ExecutionMode::kAsynchronous) {
1472     execute_inline = false;
1473   } else if (options.execution_mode ==
1474              ExecuteOptions::ExecutionMode::kSynchronous) {
1475     execute_inline = true;
1476   }
1477 
1478   if (input_deps.empty() && execute_inline) {
1479     // Synchronously call generated function.
1480 
1481     // Set denormal and rounding behavior to match the default TF
1482     // ThreadPool behavior.
1483     tensorflow::port::ScopedFlushDenormal flush;
1484     tensorflow::port::ScopedSetRound round(FE_TONEAREST);
1485 
1486     XlaCustomCallStatus status;
1487 
1488     // Call generated function.
1489     cpu_executable->compute_function()(result_buffer, &run_options, nullptr,
1490                                        buffer_pointers.data(), &status,
1491                                        nullptr);
1492 
1493     for (auto& donation_transaction : donation_transactions) {
1494       std::move(donation_transaction).Commit();
1495     }
1496 
1497     std::optional<absl::string_view> error_message =
1498         xla::CustomCallStatusGetMessage(&status);
1499     if (error_message) {
1500       return InternalError("Generated function failed: %s", *error_message);
1501     }
1502 
1503   } else {
1504     // TODO(zhangqiaorjc): Only async launch expensive computations. Need
1505     // heuristics to decide what computation is expensive.
1506     // Asynchronously call generated function.
1507 
1508     // We only created enough threads for one collective to complete.
1509     // The next collective launch will not be scheduled onto threadpool until
1510     // this one completes.
1511     if (is_a_collective_launch) {
1512       client_->SetLastCollectiveLaunchEvent(execute_event.CopyRef());
1513     }
1514     std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps_avs_copy =
1515         CopyAsyncValues(input_deps);
1516     EnqueueWorkWhenReady(
1517         host_context, input_deps,
1518         [cpu_executable, result_buffer,
1519          buffer_pointers = std::move(buffer_pointers),
1520          buffer_table = std::move(buffer_table),
1521          run_options = std::move(run_options),
1522          cpu_executable_copy = cpu_executable_,
1523          device_assignment = std::move(device_assignment),
1524          compute_reservation = std::move(compute_reservation),
1525          tuplized_arg = std::move(tuplized_arg),
1526          donation_transactions = std::move(donation_transactions),
1527          execute_event = std::move(ready_on_exit).Release(),
1528          input_deps_avs = std::move(input_deps_avs_copy)]() mutable {
1529           for (const auto& av : input_deps_avs) {
1530             if (auto* error = av->GetErrorIfPresent()) {
1531               execute_event.SetError(absl::StrFormat(
1532                   "Error dispatching computation: %s", error->message));
1533               return;
1534             }
1535           }
1536 
1537           // Set denormal and rounding behavior to match the default TF
1538           // ThreadPool behavior.
1539           tensorflow::port::ScopedFlushDenormal flush;
1540           tensorflow::port::ScopedSetRound round(FE_TONEAREST);
1541 
1542           XlaCustomCallStatus status;
1543 
1544           // Call generated function.
1545           cpu_executable->compute_function()(result_buffer, &run_options,
1546                                              nullptr, buffer_pointers.data(),
1547                                              &status, nullptr);
1548 
1549           std::optional<absl::string_view> error_message =
1550               xla::CustomCallStatusGetMessage(&status);
1551 
1552           for (auto& donation_transaction : donation_transactions) {
1553             std::move(donation_transaction).Commit();
1554           }
1555 
1556           if (error_message) {
1557             // CPU computation fails with an error.
1558             execute_event.SetError(absl::StrFormat(
1559                 "Generated function failed: %s", *error_message));
1560             return;
1561           }
1562 
1563           // CPU computation completes.
1564           execute_event.SetStateConcrete();
1565         });
1566   }
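  // In both branches `execute_event` ends up available: the synchronous path
  // relies on `ready_on_exit` going out of scope (errors are reported via the
  // returned Status), while the asynchronous path releases the event into the
  // closure, which calls SetError() or SetStateConcrete() explicitly.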
1567 
1568   // Create output TFRT buffers.
1569   const Shape& result_shape = cpu_executable_->result_shape();
1570   std::vector<std::unique_ptr<PjRtBuffer>> res;
1571   if (options.untuple_result && result_shape.IsTuple()) {
1572     res.reserve(result_buffers.size());
1573     for (int i = 0; i < result_buffers.size(); ++i) {
1574       absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> sub_buffer;
1575       sub_buffer.push_back(std::move(result_buffers[i]));
1576       // Program execution writes to output buffers so it's a definition event.
1577       absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events;
1578       definition_events.push_back(execute_event.CopyRef());
1579       auto leaf_tracked_device_buffer =
1580           std::make_unique<TrackedTfrtCpuDeviceBuffer>(
1581               /*is_tuple=*/false, std::move(sub_buffer),
1582               std::move(definition_events));
1583       auto leaf_buffer = std::make_unique<TfrtCpuBuffer>(
1584           result_shape.tuple_shapes(i), std::move(leaf_tracked_device_buffer),
1585           client_, device);
1586       res.push_back(std::move(leaf_buffer));
1587     }
1588   } else {
1589     // Program execution writes to output buffers so it's a definition event.
1590     auto tracked_device_buffer = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
1591         /*is_tuple=*/result_shape.IsTuple(), std::move(result_buffers),
1592         /*definition_event=*/execute_event);
1593     auto tfrt_output_buffer = std::make_unique<TfrtCpuBuffer>(
1594         result_shape, std::move(tracked_device_buffer), client_, device);
1595     res.push_back(std::move(tfrt_output_buffer));
1596   }
1597   std::optional<PjRtFuture<Status>> future;
1598   if (fill_future) {
1599     auto done_event = tfrt::MakeUnconstructedAsyncValueRef<Status>();
1600     execute_event.AndThen(
1601         [done_event = done_event.CopyRef(), event = execute_event.CopyRef()]() {
1602           Status s;
1603           if (auto* error = event.GetErrorIfPresent()) {
1604             s = InternalError("Compute error: %s", error->message);
1605           }
1606           done_event.emplace(std::move(s));
1607         });
1608     future = PjRtFuture<Status>(std::move(done_event));
1609   }
1610   return Result({/*future=*/std::move(future), /*buffers=*/std::move(res)});
1611 }
1612 
1613 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
1614 TfrtCpuExecutable::Execute(
1615     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1616     const ExecuteOptions& options,
1617     std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
1618   tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::Execute");
1619   if (device_assignment_ == nullptr) {
1620     return InvalidArgument("Execute expects a non-null device_assignment");
1621   }
1622 
1623   RunId run_id;
1624   tensorflow::profiler::TraceMeProducer activity(
1625       "TfrtCpuExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
1626       run_id.ToInt());
1627 
1628   const int num_addressable_devices = addressable_devices_.size();
1629 
1630   if (argument_handles.size() != num_addressable_devices) {
1631     return InvalidArgument(
1632         "Attempted to execute with %d argument lists when local device "
1633         "count is %d (total replica count: %d, partition count: %d)",
1634         argument_handles.size(), num_addressable_devices, num_replicas(),
1635         num_partitions());
1636   }
1637 
1638   VLOG(1) << "Executing computation " << name()
1639           << "; num_replicas=" << num_replicas()
1640           << " num_partitions=" << num_partitions()
1641           << " num_addressable_devices=" << num_addressable_devices;
1642 
1643   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
1644       num_addressable_devices);
1645   if (returned_futures.has_value()) {
1646     returned_futures->resize(num_addressable_devices);
1647   }
1648   if (num_addressable_devices == 1) {
1649     // Fast-path if there is only one device — run the computation on the
1650     // current thread.
1651     const int replica = addressable_device_logical_ids_[0].replica;
1652     const int partition = addressable_device_logical_ids_[0].partition;
1653 
1654     auto statusor = ExecuteHelper(
1655         argument_handles[0], replica, partition, run_id, options,
1656         /*last_collective_launch_event=*/tfrt::AsyncValueRef<CpuEvent>(),
1657         returned_futures.has_value());
1658 
1659     if (!statusor.ok()) {
1660       return std::move(statusor).status();
1661     }
1662 
1663     wrapped_results[0] = std::move(statusor->buffers);
1664     if (returned_futures.has_value()) {
1665       (*returned_futures)[0] = std::move(*statusor->future);
1666     }
1667 
1668   } else {
1669     // Gang schedule collectives to ensure that collectives with the same RunId
1670     // are run at the same time. We conservatively run only one collective at a
1671     // time, because we may not have enough threads to run an arbitrary number of
1672     // collectives concurrently.
1673     tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event =
1674         client_->GetLastCollectiveLaunchEvent();
1675 
1676     absl::Mutex mu;
1677     int running = num_addressable_devices;
1678     int failed = 0;
1679     Status first_failure_status;
1680 
1681     for (int i = 0; i < num_addressable_devices; ++i) {
1682       const int replica = addressable_device_logical_ids_[i].replica;
1683       const int partition = addressable_device_logical_ids_[i].partition;
1684       tfrt::EnqueueWork(client_->GetHostContext(), [&, replica, partition, i] {
1685         auto statusor =
1686             ExecuteHelper(argument_handles[i], replica, partition, run_id,
1687                           options, last_collective_launch_event.CopyRef(),
1688                           returned_futures.has_value());
1689         if (statusor.ok()) {
1690           wrapped_results[i] = std::move(statusor->buffers);
1691           if (returned_futures.has_value()) {
1692             (*returned_futures)[i] = std::move(*statusor->future);
1693           }
1694         }
1695 
1696         absl::MutexLock lock(&mu);
1697         --running;
1698         if (!statusor.ok()) {
1699           if (failed == 0) {
1700             first_failure_status = AppendStatus(
1701                 std::move(statusor).status(),
1702                 absl::StrFormat(
1703                     "while running replica %d and partition %d of a "
1704                     "replicated computation (other "
1705                     "replicas may have failed as well).",
1706                     replica, partition));
1707           }
1708           ++failed;
1709         }
1710       });
1711     }
1712 
1713     {
1714       auto done_running = [&]() {
1715         mu.AssertHeld();
1716         return running == 0;
1717       };
1718       absl::MutexLock lock(&mu);
1719       mu.Await(absl::Condition(&done_running));
1720     }
1721 
1722     if (!first_failure_status.ok()) return first_failure_status;
1723   }
1724   VLOG(1) << "Replicated execution complete.";
1725 
1726   return wrapped_results;
1727 }
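// A minimal hypothetical caller-side sketch of a replicated launch, assuming
// `executable`, `args`, and `options` already exist:
//
//   std::optional<std::vector<PjRtFuture<Status>>> futures(std::in_place);
//   TF_ASSIGN_OR_RETURN(auto per_device_outputs,
//                       executable->Execute(args, options, futures));
//   for (auto& future : *futures) {
//     TF_RETURN_IF_ERROR(future.Await());  // Wait for each device's result.
//   }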
1728 
1729 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1730 TfrtCpuExecutable::ExecuteSharded(
1731     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1732     const ExecuteOptions& options,
1733     std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
1734   tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecuteSharded");
1735   if (device_assignment_ == nullptr) {
1736     return InvalidArgument("ExecuteShard expects a non-null device_assignment");
1737   }
1738   for (int i = 0; i < addressable_devices_.size(); ++i) {
1739     if (addressable_devices_[i] == device) {
1740       VLOG(1) << "ExecuteShard executes computation " << name()
1741               << " on assigned replica/partition on device "
1742               << device->DebugString();
1743       TF_ASSIGN_OR_RETURN(
1744           auto result,
1745           ExecuteHelper(
1746               argument_handles, addressable_device_logical_ids_[i].replica,
1747               addressable_device_logical_ids_[i].partition, RunId(), options,
1748               /*last_collective_launch_event=*/
1749               tfrt::AsyncValueRef<CpuEvent>(), fill_future));
1750       returned_future = std::move(result.future);
1751       return std::move(result.buffers);
1752     }
1753   }
1754   return InvalidArgument(
1755       "ExecuteShard attempted to execute on device id %d which is not "
1756       "addressable by this client",
1757       device->id());
1758 }
1759 
1760 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1761 TfrtCpuExecutable::ExecutePortable(
1762     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1763     const ExecuteOptions& options,
1764     std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
1765   tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecutePortable");
1766   if (device_assignment_ != nullptr) {
1767     return InvalidArgument("ExecutePortable gets a non-portable executable");
1768   }
1769   if (num_replicas() != 1 || num_partitions() != 1) {
1770     return InvalidArgument(
1771         "ExecutePortable expects a single-core executable but gets "
1772         "one with %d replica %d partition",
1773         num_replicas(), num_partitions());
1774   }
1775   if (device == nullptr) {
1776     return InvalidArgument("ExecutePortable expects a device to be specified");
1777   }
1778   VLOG(1) << "ExecutePortable executes single-core portable executable "
1779           << name();
1780   TF_ASSIGN_OR_RETURN(
1781       auto result,
1782       ExecuteHelper(
1783           argument_handles,
1784           /*replica=*/0,
1785           /*partition=*/0, RunId(), options,
1786           /*last_collective_launch_event=*/tfrt::AsyncValueRef<CpuEvent>(),
1787           fill_future, tensorflow::down_cast<TfrtCpuDevice*>(device)));
1788   returned_future = std::move(result.future);
1789   return std::move(result.buffers);
1790 }
1791 }  // namespace xla
1792