/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/util.h"

#define EIGEN_USE_THREADS

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_xfeed.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime

namespace xla {
namespace {

// A RAII helper class used to set an AsyncValueRef<CpuEvent> to a ready state
// upon destruction. In many places in the PjRt implementation, a function has
// multiple return statements, all of which require setting some
// AsyncValueRef<CpuEvent> to be ready. This class makes such code more robust
// by setting the AsyncValue in the destructor.
class MarkEventReadyOnExit {
 public:
  explicit MarkEventReadyOnExit(tfrt::AsyncValueRef<CpuEvent> event)
      : event_(std::move(event)) {}

  MarkEventReadyOnExit(const MarkEventReadyOnExit&) = delete;
  MarkEventReadyOnExit& operator=(const MarkEventReadyOnExit&) = delete;
  MarkEventReadyOnExit(MarkEventReadyOnExit&&) = default;
  MarkEventReadyOnExit& operator=(MarkEventReadyOnExit&&) = default;

  ~MarkEventReadyOnExit() {
    if (event_) event_.SetStateConcrete();
  }

  tfrt::AsyncValueRef<CpuEvent> Release() && { return std::move(event_); }

 private:
  tfrt::AsyncValueRef<CpuEvent> event_;
};
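
// Illustrative usage sketch (not part of the implementation; the names below
// are hypothetical): a function with several early returns can rely on the
// destructor instead of setting the event on every path.
//
//   Status DoWork(tfrt::AsyncValueRef<CpuEvent> usage_event) {
//     MarkEventReadyOnExit ready_on_exit(std::move(usage_event));
//     if (ShouldBailOut()) return InvalidArgument("...");  // Event still set.
//     // ... do the work ...
//     return OkStatus();  // Event is set here as well.
//   }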

}  // namespace

static const char kCpuPlatformName[] = "cpu";
static constexpr size_t kSmallDataTransferByteSize = 102400;  // 100 KiB

static tfrt::AsyncValueRef<CpuEvent> GetOrCreateReadyEvent(
    tfrt::HostContext* host_context) {
  static const auto* ready_event = new tfrt::AsyncValueRef<CpuEvent>(
      tfrt::MakeAvailableAsyncValueRef<CpuEvent>(host_context));
  return ready_event->CopyRef();
}

TfrtCpuDevice::TfrtCpuDevice(int id, bool asynchronous)
    : id_(id),
      max_inflight_computations_semaphore_(/*capacity=*/asynchronous ? 32 : 1) {
  debug_string_ = absl::StrCat("TFRT_CPU_", id);
  to_string_ = absl::StrCat("CpuDevice(id=", id, ")");
}

absl::string_view TfrtCpuDevice::device_kind() const {
  return kCpuPlatformName;
}

absl::string_view TfrtCpuDevice::DebugString() const { return debug_string_; }

absl::string_view TfrtCpuDevice::ToString() const { return to_string_; }

Status TfrtCpuDevice::TransferToInfeed(const LiteralSlice& literal) {
  return TransferLiteralToInfeedOnCpu(local_hardware_id(), literal);
}

Status TfrtCpuDevice::TransferFromOutfeed(MutableBorrowingLiteral literal) {
  return TransferLiteralFromOutfeedOnCpu(local_hardware_id(), literal);
}

static int CpuDeviceCount() {
  // By default we fix the number of devices to one. However we do let the user
  // override this behavior to help run tests on the host that run models in
  // parallel across multiple devices, e.g. pmap.
  return GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
}
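
// For reference (an assumption about typical usage, not exercised in this
// file): the device count above is usually overridden through the XLA_FLAGS
// environment variable, e.g.
//
//   XLA_FLAGS=--xla_force_host_platform_device_count=8 python train.py
//
// which makes GetTfrtCpuClient() below create eight virtual CPU devices.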

static StatusOr<std::vector<std::unique_ptr<TfrtCpuDevice>>> GetTfrtCpuDevices(
    bool asynchronous, int cpu_device_count) {
  std::vector<std::unique_ptr<TfrtCpuDevice>> devices;
  for (int i = 0; i < cpu_device_count; ++i) {
    auto device = std::make_unique<TfrtCpuDevice>(
        /*id=*/i, asynchronous);
    devices.push_back(std::move(device));
  }
  return std::move(devices);
}

StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous,
                                                       int cpu_device_count) {
  // TODO(zhangqiaorjc): Allow users to set the number of threads.
  // `num_blocking_threads=16` is picked arbitrarily for now.
  // We need at least `cpu_device_count` threads to launch one collective.
  int num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count);
  auto host_context = std::make_unique<tfrt::HostContext>(
      [](const tfrt::DecodedDiagnostic& diag) {
        LOG(ERROR) << "Encountered runtime error: " << diag.message << "\n";
      },
      tfrt::CreateMallocAllocator(),
      tfrt::CreateMultiThreadedWorkQueue(
          /*num_threads=*/num_threads,
          /*num_blocking_threads=*/16));

  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
                      GetTfrtCpuDevices(asynchronous, cpu_device_count));

  return std::unique_ptr<PjRtClient>(std::make_unique<TfrtCpuClient>(
      /*process_index=*/0, std::move(devices), std::move(host_context)));
}

StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous) {
  return GetTfrtCpuClient(asynchronous, CpuDeviceCount());
}
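
// Illustrative caller-side sketch (assumes the standard PjRtClient accessors
// such as addressable_devices(); not part of this translation unit):
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
//                       GetTfrtCpuClient(/*asynchronous=*/true));
//   for (PjRtDevice* device : client->addressable_devices()) {
//     LOG(INFO) << "CPU device: " << device->DebugString();
//   }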

TfrtCpuClient::TfrtCpuClient(
    int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
    std::unique_ptr<tfrt::HostContext> host_ctx)
    : process_index_(process_index),
      owned_devices_(std::move(devices)),
      host_ctx_(std::move(host_ctx)),
      computation_placer_(std::make_unique<ComputationPlacer>()),
      eigen_intraop_pool_(new tensorflow::thread::ThreadPool(
          tensorflow::Env::Default(), "XLAEigen", DefaultThreadPoolSize())),
      eigen_intraop_device_(
          new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(),
                                      eigen_intraop_pool_->NumThreads())),
      last_collective_launch_event_(
          tfrt::MakeAvailableAsyncValueRef<CpuEvent>(host_ctx_.get())),
      transpose_cache_(1024) {
  for (const std::unique_ptr<TfrtCpuDevice>& device : owned_devices_) {
    devices_.push_back(device.get());
    CHECK(id_to_device_.insert({device->id(), device.get()}).second)
        << "Duplicate device id: " << device->id();

    device->SetClient(this);
    if (device->IsAddressable()) {
      int idx = device->local_hardware_id();
      if (idx >= addressable_devices_.size()) {
        addressable_devices_.resize(idx + 1);
      }
      CHECK(addressable_devices_[idx] == nullptr) << idx;
      addressable_devices_[idx] = device.get();
    }
  }
  for (int idx = 0; idx < addressable_devices_.size(); ++idx) {
    CHECK(addressable_devices_[idx] != nullptr) << idx;
  }
  LOG(INFO) << "TfrtCpuClient created.";
}

TfrtCpuClient::~TfrtCpuClient() { LOG(INFO) << "TfrtCpuClient destroyed."; }

StatusOr<PjRtDevice*> TfrtCpuClient::LookupDevice(int device_id) const {
  auto it = id_to_device_.find(device_id);
  if (it != id_to_device_.end()) {
    return it->second;
  }
  return InvalidArgument("No matching device found for device_id %d",
                         device_id);
}

StatusOr<PjRtDevice*> TfrtCpuClient::LookupAddressableDevice(
    int local_hardware_id) const {
  for (auto* device : addressable_devices_) {
    if (local_hardware_id == device->local_hardware_id()) {
      return device;
    }
  }
  return InvalidArgument("No matching device found for local_hardware_id %d",
                         local_hardware_id);
}

StatusOr<DeviceAssignment> TfrtCpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  return computation_placer_->AssignDevices(num_replicas, num_partitions);
}

StatusOr<std::unique_ptr<HloCostAnalysis>> TfrtCpuClient::GetHloCostAnalysis() {
  return std::make_unique<HloCostAnalysis>(cpu::CpuExecutable::ShapeSizeBytes);
}

StatusOr<std::optional<std::string>> TfrtCpuClient::ExecutableFingerprint(
    const PjRtLoadedExecutable& executable) const {
  return std::optional<std::string>();
}

static StatusOr<std::unique_ptr<xla::Executable>> JitCompile(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& build_options,
    const ExecutionOptions& execution_options) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  // Unoptimized HloModuleConfig.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> hlo_module_config,
      CreateModuleConfig(program_shape, argument_layouts, &execution_options,
                         execution_options.num_replicas(),
                         /*num_threads=*/std::nullopt,
                         /*aot_options=*/nullptr));

  // Unoptimized HloModule.
  const xla::HloModuleProto& hlo_module_proto = computation.proto();
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> hlo_module,
      xla::HloModule::CreateFromProto(hlo_module_proto, *hlo_module_config));
  VLOG(3) << "Unoptimized HLO module: " << hlo_module->ToString();
  static constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";
  DumpHloModuleIfEnabled(*hlo_module, kBeforeOptimizationsDumpName);

  // Run HLO passes.
  cpu::CpuCompiler compiler;
  xla::Compiler::CompileOptions dummy;
  TF_ASSIGN_OR_RETURN(hlo_module,
                      compiler.RunHloPasses(std::move(hlo_module),
                                            /*stream_exec=*/nullptr, dummy));

  // Run backend.
  return compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
                             dummy);
}

// Find the root instruction of the entry computation.
static const InstructionValueSet& GetRootValueSet(
    const BufferAssignment& assignment, const HloModule& module) {
  return assignment.dataflow_analysis().GetInstructionValueSet(
      module.entry_computation()->root_instruction());
}

// The buffer table is indexed by buffer allocation indices. The output buffer
// is made up of a subset of those buffer allocations (for a tuple output, this
// includes the tuple index table). This helper finds the buffer allocation
// indices in the buffer assignment that make up the output buffer. It is used
// by CreateResultShapedBuffer to reconstruct the output buffer from the buffer
// table allocated by MemoryForAllocation.
static StatusOr<absl::InlinedVector<BufferAllocation::Index, 4>>
FindResultBufferAllocationIndex(const BufferAssignment& assignment,
                                const HloModule& module) {
  absl::InlinedVector<BufferAllocation::Index, 4> buffer_indices;
  const InstructionValueSet& root_value_set =
      GetRootValueSet(assignment, module);
  const Shape& result_shape = module.result_shape();
  if (!result_shape.IsTuple()) {
    // Find the buffer allocation that corresponds to the output buffer.
    const HloValueSet& sources = root_value_set.element({});
    // The points-to set is unambiguous, so the set should be a singleton.
    CHECK_EQ(1, sources.values().size());
    const HloValue* value_source = sources.values()[0];
    HloInstruction* src = value_source->instruction();
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                        assignment.GetUniqueSlice(src, value_source->index()));
    const BufferAllocation::Index buffer_index = slice.index();
    buffer_indices.push_back(buffer_index);
    return {std::move(buffer_indices)};
  }
  buffer_indices.reserve(result_shape.tuple_shapes_size());
  for (int i = 0; i < result_shape.tuple_shapes_size(); ++i) {
    // Find the buffer allocations that correspond to the output tuple,
    // including the tuple index table.
    const HloValueSet& sources = root_value_set.element({i});
    // The points-to set is unambiguous, so the set should be a singleton.
    CHECK_EQ(1, sources.values().size());
    const HloValue* value_source = sources.values()[0];
    HloInstruction* src = value_source->instruction();
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                        assignment.GetUniqueSlice(src, value_source->index()));
    const BufferAllocation::Index buffer_index = slice.index();
    buffer_indices.push_back(buffer_index);
  }
  return {std::move(buffer_indices)};
}

StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
    const XlaComputation& computation, CompileOptions options) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuClient::Compile");
  ExecutableBuildOptions& build_options = options.executable_build_options;

  int num_replicas;
  int num_partitions;
  std::shared_ptr<DeviceAssignment> device_assignment;
  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
      options.compile_portable_executable, &options.executable_build_options,
      [this](int num_replicas, int num_partitions) {
        return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
      },
      &num_replicas, &num_partitions, &device_assignment));

  std::vector<const Shape*> argument_layout_pointers;
  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
      computation, &LayoutUtil::GetWithDefaultLayout, options.argument_layouts,
      &options.executable_build_options, &argument_layout_pointers));

  std::vector<PjRtLoadedExecutable::LogicalDeviceIds>
      addressable_device_logical_ids;
  std::vector<PjRtDevice*> addressable_devices;
  if (device_assignment != nullptr) {
    addressable_device_logical_ids.reserve(num_replicas * num_partitions);
    addressable_devices.reserve(num_replicas * num_partitions);
    for (int replica = 0; replica < num_replicas; ++replica) {
      for (int partition = 0; partition < num_partitions; ++partition) {
        int device_id = (*device_assignment)(replica, partition);
        TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
        if (device->process_index() != process_index()) {
          VLOG(3) << "Non-local device: " << device_id;
          continue;
        }
        PjRtLoadedExecutable::LogicalDeviceIds logical_device_ids;
        logical_device_ids.replica = replica;
        logical_device_ids.partition = partition;
        addressable_device_logical_ids.push_back(
            std::move(logical_device_ids));
        addressable_devices.push_back(device);
      }
    }
    if (addressable_devices.empty()) {
      return InvalidArgument(
          "Device assignment (%s) does not have any local devices.",
          device_assignment->ToString());
    }

    if (build_options.device_ordinal() < 0) {
      build_options.set_device_ordinal(
          addressable_devices.front()->local_hardware_id());
    }
  }

  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  ExecutionOptions execution_options =
      CreateExecutionOptions(build_options, &program_shape);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> cpu_executable,
                      JitCompile(computation, argument_layout_pointers,
                                 build_options, execution_options));
  auto cpu_executable_ptr =
      tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable.get());

  // `buffer_table[result_slice.index()]` points to the result buffer:
  // if the output is a tuple, it points to the buffer index table;
  // if the output is a non-tuple, it points to the buffer itself.
  TF_ASSIGN_OR_RETURN(
      const BufferAllocation::Slice result_slice,
      cpu_executable_ptr->buffer_assignment().GetUniqueTopLevelOutputSlice());

  // `result_buffer_indices` has the buffer allocation indices that make up the
  // output buffer (which could be a tuple).
  TF_ASSIGN_OR_RETURN(
      auto result_buffer_indices,
      FindResultBufferAllocationIndex(cpu_executable_ptr->buffer_assignment(),
                                      cpu_executable->module()));

  auto executable = std::make_unique<TfrtCpuExecutable>(
      num_replicas, num_partitions, std::move(device_assignment),
      options.parameter_is_tupled_arguments, std::move(cpu_executable),
      result_slice.index(), std::move(result_buffer_indices),
      std::move(addressable_device_logical_ids), std::move(addressable_devices),
      this);
  TF_RETURN_IF_ERROR(
      executable->SetUpDonation(options.parameter_is_tupled_arguments));

  return std::unique_ptr<PjRtLoadedExecutable>(std::move(executable));
}

StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
    mlir::ModuleOp module, CompileOptions options) {
  XlaComputation xla_computation;
  TF_RETURN_IF_ERROR(MlirToXlaComputation(
      module, xla_computation,
      /*use_tuple_args=*/options.parameter_is_tupled_arguments,
      /*return_tuple=*/false));
  return Compile(xla_computation, options);
}

StatusOr<std::unique_ptr<TfrtCpuBuffer>> AllocateDestinationBuffer(
    const Shape& on_device_shape,
    absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events,
    TfrtCpuDevice* device, TfrtCpuClient* client) {
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers;
  if (!on_device_shape.IsTuple()) {
    size_t byte_size = ShapeUtil::ByteSizeOf(on_device_shape);
    TF_ASSIGN_OR_RETURN(auto device_buffer,
                        MaybeOwningCpuMemory::AllocateShared(byte_size));
    buffers.push_back(std::move(device_buffer));
    return std::make_unique<TfrtCpuBuffer>(
        on_device_shape,
        std::make_unique<TrackedTfrtCpuDeviceBuffer>(
            /*is_tuple=*/false, std::move(buffers),
            std::move(definition_events)),
        client, device);
  }
  // Tuple case.
  buffers.reserve(on_device_shape.tuple_shapes().size());
  for (const auto& leaf_shape : on_device_shape.tuple_shapes()) {
    size_t byte_size = ShapeUtil::ByteSizeOf(leaf_shape);
    TF_ASSIGN_OR_RETURN(auto device_buffer,
                        MaybeOwningCpuMemory::AllocateShared(byte_size));
    buffers.push_back(std::move(device_buffer));
  }
  return std::make_unique<TfrtCpuBuffer>(
      on_device_shape,
      std::make_unique<TrackedTfrtCpuDeviceBuffer>(
          /*is_tuple=*/true, std::move(buffers), std::move(definition_events)),
      client, device);
}

StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::CreateViewOfDeviceBuffer(
    void* device_ptr, const Shape& shape, PjRtDevice* device,
    std::function<void()> on_delete_callback) {
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers;
  size_t byte_size = ShapeUtil::ByteSizeOf(shape);
  auto non_owning_buffer =
      std::make_shared<MaybeOwningCpuMemory>(device_ptr, byte_size);
  buffers.push_back(std::move(non_owning_buffer));
  auto tracked_device_buffer = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
      /*is_tuple=*/false, std::move(buffers),
      /*definition_event=*/tfrt::MakeAvailableAsyncValueRef<CpuEvent>(),
      std::move(on_delete_callback));
  return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>(
      shape, std::move(tracked_device_buffer), this,
      tensorflow::down_cast<TfrtCpuDevice*>(device)));
}

StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::CreateUninitializedBuffer(
    const Shape& shape, PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme(
      "TfrtCpuClient::CreateUninitializedBuffer");
  VLOG(1) << "TfrtCpuClient::CreateUninitializedBuffer: shape: "
          << shape.DebugString() << " device: " << device->DebugString();
  return AllocateDestinationBuffer(
      shape, /*definition_events=*/{},
      tensorflow::down_cast<TfrtCpuDevice*>(device), this);
}

StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::BufferFromHostBuffer(
    const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
    std::optional<absl::Span<int64_t const>> byte_strides,
    HostBufferSemantics host_buffer_semantics,
    std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostBuffer");
  Shape shape = ShapeUtil::MakeShape(type, dims);
  VLOG(2) << "TfrtCpuClient::BufferFromHostBuffer: shape: " << shape.ToString()
          << " device: " << device->DebugString();
  bool has_default_layout =
      !byte_strides || HasMajorToMinorLayout(type, dims, *byte_strides);
  // If the input buffer has a default layout and is sufficiently aligned, we
  // can simply point to the input array's data without any further copies. At
  // the time of writing we require a 16-byte alignment because XLA may generate
  // code which requires it.
  bool can_use_zero_copy =
      has_default_layout &&
      host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
      ((absl::bit_cast<std::uintptr_t>(data) &
        (cpu_function_runtime::MinAlign() - 1)) == 0);
  absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers;
  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events;
  std::function<void()> on_delete_callback;
  size_t byte_size = ShapeUtil::ByteSizeOf(shape);
  if (can_use_zero_copy) {
    auto device_buffer = std::make_shared<MaybeOwningCpuMemory>(
        const_cast<void*>(data), byte_size);
    buffers.push_back(std::move(device_buffer));
    on_delete_callback = std::move(on_done_with_host_buffer);
  } else {
    TF_ASSIGN_OR_RETURN(auto device_buffer,
                        MaybeOwningCpuMemory::AllocateShared(byte_size));
    auto dst_data_ptr = device_buffer->data();
    buffers.push_back(device_buffer);
    if (!has_default_layout) {
      // If the input array does not have a major-to-minor layout, transpose it
      // into major-to-minor layout. Currently we choose to always do this
      // synchronously.
      // TODO(phawkins): consider performing the transpose asynchronously.
      // TODO(phawkins): parallelize the transpose.
      std::shared_ptr<TransposePlan> transpose;
      {
        absl::InlinedVector<int64_t, 4> permutation(dims.size());
        absl::c_iota(permutation, 0);
        absl::MutexLock lock(&transpose_mu_);
        TF_ASSIGN_OR_RETURN(
            transpose, transpose_cache_.GetOrCreate(
                           primitive_util::ByteWidth(type), dims, permutation,
                           TransposePlan::Striding{*byte_strides}));
      }
      transpose->Execute(data, dst_data_ptr);
      if (on_done_with_host_buffer) {
        on_done_with_host_buffer();
        on_done_with_host_buffer = nullptr;
      }
    } else {
      bool should_sync_copy =
          host_buffer_semantics ==
              HostBufferSemantics::kImmutableOnlyDuringCall ||
          (byte_size < kSmallDataTransferByteSize);
      if (should_sync_copy) {
        std::memcpy(dst_data_ptr, data, byte_size);
        if (on_done_with_host_buffer) {
          on_done_with_host_buffer();
          on_done_with_host_buffer = nullptr;
        }
      } else {
        tfrt::AsyncValueRef<CpuEvent> copy_event =
            tfrt::MakeConstructedAsyncValueRef<CpuEvent>(host_ctx_.get());
        definition_events.push_back(copy_event.CopyRef());
        tfrt::EnqueueWork(
            host_ctx_.get(),
            [device_buffer = std::move(device_buffer), dst_data_ptr, data,
             byte_size, copy_event = std::move(copy_event),
             on_done_with_host_buffer =
                 std::move(on_done_with_host_buffer)]() mutable {
              tensorflow::profiler::TraceMe traceme("H2D Dispatch");
              std::memcpy(dst_data_ptr, data, byte_size);
              if (on_done_with_host_buffer) {
                on_done_with_host_buffer();
                on_done_with_host_buffer = nullptr;
              }
              // Signal copy is complete.
              copy_event.SetStateConcrete();
            });
      }
    }
  }
  auto tracked_device_buffer = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
      /*is_tuple=*/false, std::move(buffers), std::move(definition_events),
      std::move(on_delete_callback));
  return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>(
      shape, std::move(tracked_device_buffer), this,
      tensorflow::down_cast<TfrtCpuDevice*>(device)));
}
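
// Illustrative caller-side sketch of the zero-copy path above (hypothetical
// variable names; `data` must stay valid, default-laid-out, and 16-byte
// aligned until the `on_done_with_host_buffer` callback runs):
//
//   std::vector<float> data(16, 1.0f);
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtBuffer> buffer,
//       client->BufferFromHostBuffer(
//           data.data(), F32, /*dims=*/{16}, /*byte_strides=*/std::nullopt,
//           PjRtClient::HostBufferSemantics::kZeroCopy,
//           /*on_done_with_host_buffer=*/[]() {}, device));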

StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::BufferFromHostLiteral(
    const LiteralSlice& literal, PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostLiteral");
  VLOG(1) << "TfrtCpuClient::BufferFromHostLiteral: shape: "
          << literal.shape().DebugString()
          << " device: " << device->DebugString();
  const Shape& shape = literal.shape();

  // Add a placeholder definition event for each leaf buffer when creating the
  // buffer. They are set only after the H2D dispatch.
  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events;
  absl::InlinedVector<tfrt::RCReference<tfrt::AsyncValue>, 4> avs;
  int num_leaf_buffers = shape.IsTuple() ? shape.tuple_shapes_size() : 1;
  for (int i = 0; i < num_leaf_buffers; ++i) {
    tfrt::AsyncValueRef<CpuEvent> definition_event =
        tfrt::MakeConstructedAsyncValueRef<CpuEvent>(GetHostContext());
    definition_events.push_back(definition_event.CopyRef());
    avs.push_back(std::move(definition_event));
  }
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TfrtCpuBuffer> output_buffer,
                      AllocateDestinationBuffer(
                          shape, std::move(definition_events),
                          tensorflow::down_cast<TfrtCpuDevice*>(device), this));

  auto usage_event = tfrt::MakeAvailableAsyncValueRef<CpuEvent>();
  auto* device_buffer = output_buffer->AcquireUsage(std::move(usage_event));
  CHECK(device_buffer);
  if (!shape.IsTuple()) {
    // It is OK to capture the `device_buffer` pointer because `output_buffer`
    // can't be deleted until all the usage holds have gone away.
    tfrt::EnqueueWork(GetHostContext(), [literal, av = avs[0].CopyRef(),
                                         device_buffer, shape]() mutable {
      tensorflow::profiler::TraceMe traceme("H2D Dispatch");
      const std::shared_ptr<MaybeOwningCpuMemory>& b =
          device_buffer->Buffers()[0];
      CHECK_EQ(literal.size_bytes(), b->size());
      std::memcpy(b->data(), literal.untyped_data(), b->size());
      // Signal copy is complete.
      av->SetStateConcrete();
    });
  } else {
    // For tuples, transfer the leaf literals individually in parallel.
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      // It is OK to capture the `device_buffer` pointer because `output_buffer`
      // can't be deleted until all the usage holds have gone away.
      tfrt::EnqueueWork(GetHostContext(), [i, literal, av = avs[i].CopyRef(),
                                           shape, device_buffer]() mutable {
        tensorflow::profiler::TraceMe traceme("H2D Dispatch");
        auto slice = LiteralSlice(literal, {i});
        const std::shared_ptr<MaybeOwningCpuMemory>& b =
            device_buffer->Buffers()[i];
        CHECK_EQ(slice.size_bytes(), b->size());
        std::memcpy(b->data(), slice.untyped_data(), slice.size_bytes());
        // Signal copy is complete.
        av->SetStateConcrete();
      });
    }
  }
  return std::unique_ptr<PjRtBuffer>(std::move(output_buffer));
}

TfrtCpuBuffer::TfrtCpuBuffer(
    Shape on_device_shape,
    std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer,
    TfrtCpuClient* client, TfrtCpuDevice* device)
    : client_(client),
      on_device_shape_(std::move(on_device_shape)),
      device_(device),
      tracked_device_buffer_(std::move(tracked_device_buffer)) {}

TfrtCpuBuffer::~TfrtCpuBuffer() {
  Delete();
  CHECK_EQ(external_reference_counter_, 0);
}

StatusOr<size_t> TfrtCpuBuffer::GetOnDeviceSizeInBytes() const {
  return ShapeUtil::ByteSizeOf(on_device_shape_);
}

StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
TfrtCpuBuffer::AcquireExternalReference() {
  class ScopedExternalReference : public PjRtBuffer::ExternalReference {
   public:
    explicit ScopedExternalReference(TfrtCpuBuffer* buffer,
                                     std::shared_ptr<MaybeOwningCpuMemory> data)
        : buffer_(buffer), data_(std::move(data)) {
      DCHECK(data_);
      data_ptr_ = data_->data();
    }

    ~ScopedExternalReference() override { buffer_->DropExternalReference(); }

   private:
    TfrtCpuBuffer* buffer_ = nullptr;
    // Keep a reference to the underlying data used. Note that it is still the
    // user's responsibility to synchronize reads and writes to the data.
    std::shared_ptr<MaybeOwningCpuMemory> data_;
  };

  absl::MutexLock lock(&mu_);
  if (tracked_device_buffer_ == nullptr) {
    return InvalidArgument("Buffer has been deleted or donated.");
  }

  ++external_reference_counter_;

  return {std::make_unique<ScopedExternalReference>(
      this, tracked_device_buffer_->Buffers()[0])};
}
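
// Illustrative usage sketch (assumes the standard
// PjRtBuffer::ExternalReference accessor OpaqueDeviceMemoryDataPointer();
// names are hypothetical):
//
//   TF_ASSIGN_OR_RETURN(auto external_ref,
//                       buffer->AcquireExternalReference());
//   void* raw = external_ref->OpaqueDeviceMemoryDataPointer();
//   // ... read `raw` while `external_ref` is alive; the caller is
//   // responsible for synchronizing with any in-flight writes.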

class TrackedCpuDeviceBufferExternalReference
    : public PjRtBuffer::ExternalReference {
 public:
  explicit TrackedCpuDeviceBufferExternalReference(
      std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer)
      : tracked_device_buffer_(std::move(tracked_device_buffer)) {
    data_ptr_ = tracked_device_buffer_->Buffers()[0]->data();
  }

  ~TrackedCpuDeviceBufferExternalReference() override = default;

 private:
  std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer_;
};

StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
TfrtCpuBuffer::ReleaseDeviceMemoryOwnership(
    bool wait_for_operations_to_complete) {
  if (on_device_shape_.IsTuple()) {
    return InvalidArgument(
        "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer,
      Release(wait_for_operations_to_complete));

  std::unique_ptr<PjRtBuffer::ExternalReference> ref;
  if (tracked_device_buffer) {
    ref = std::make_unique<TrackedCpuDeviceBufferExternalReference>(
        std::move(tracked_device_buffer));
  }
  return ref;
}

void TfrtCpuBuffer::CommitDonation() {
  absl::MutexLock lock(&mu_);
  CHECK(pending_donation_);
  CHECK(!tracked_device_buffer_);
  pending_donation_ = false;
}

void TfrtCpuBuffer::AbortDonation(
    std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer) {
  absl::MutexLock lock(&mu_);
  CHECK(pending_donation_);
  CHECK(!tracked_device_buffer_);
  pending_donation_ = false;
  tracked_device_buffer_ = std::move(device_buffer);
}

void TfrtCpuBuffer::Delete() {
  auto device_buffer = ReleaseBufferLocked();
  if (device_buffer == nullptr) return;

  // Now that all holds have completed and no more can be added, we can get
  // the final set of usage events.
  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> usage_events =
      device_buffer->LockUseAndTransferUsageEvents();

  std::vector<tfrt::AsyncValue*> event_avs;
  event_avs.reserve(usage_events.size() + 1);
  for (auto& event : usage_events) {
    event_avs.push_back(event.GetAsyncValue());
  }

  // We should also wait for the definition event.
  event_avs.push_back(device_buffer->definition_event().GetAsyncValue());

  tfrt::RunWhenReady(event_avs,
                     [device_buffer = std::move(device_buffer)]() mutable {
                       device_buffer.reset();
                     });
}

bool TfrtCpuBuffer::IsDeleted() {
  absl::MutexLock lock(&mu_);
  return tracked_device_buffer_ == nullptr;
}

std::unique_ptr<TrackedTfrtCpuDeviceBuffer>
TfrtCpuBuffer::ReleaseBufferLocked() {
  absl::MutexLock lock(&mu_);
  auto condition = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) {
    return !pending_donation_;
  };
  mu_.Await(absl::Condition(&condition));
  return std::move(tracked_device_buffer_);
}

StatusOr<std::unique_ptr<TrackedTfrtCpuDeviceBuffer>> TfrtCpuBuffer::Release(
    bool wait_for_operations_to_complete) {
  std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer =
      ReleaseBufferLocked();
  if (device_buffer == nullptr) return {nullptr};

  absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> events;
  // Now that all holds have completed and no more can be added, we can get
  // the final set of usage events.
  events = device_buffer->LockUseAndTransferUsageEvents();

  if (wait_for_operations_to_complete) {
    // Block the host until all usage events have completed. Usage events
    // dominate definition events, so this also waits for the buffer to be
    // defined. Return the first error encountered.
    Status first_error;
    for (const auto& av : events) {
      client_->GetHostContext()->Await(av.CopyRCRef());
      if (auto* error = av.GetErrorIfPresent()) {
        first_error.Update(InternalError("Error Execute: %s", error->message));
      }
    }
    if (!first_error.ok()) return std::move(first_error);
  }

  return device_buffer;
}

TrackedTfrtCpuDeviceBuffer* TfrtCpuBuffer::AcquireUsage(
    tfrt::AsyncValueRef<CpuEvent> usage_event) {
  absl::MutexLock lock(&mu_);
  if (!tracked_device_buffer_) {
    return nullptr;
  }

  tracked_device_buffer_->AddUsageEvents(absl::MakeSpan(&usage_event, 1));
  return tracked_device_buffer_.get();
}

StatusOr<TfrtCpuBuffer::DonationTransaction> TfrtCpuBuffer::AcquireDonation() {
  absl::MutexLock lock(&mu_);

  if (tracked_device_buffer_ == nullptr) {
    return InvalidArgument("Donation requested for invalid buffer");
  }

  if (external_reference_counter_ > 0) {
    return InvalidArgument(
        "Donation requested for buffer with external reference");
  }

  CHECK(!pending_donation_);
  pending_donation_ = true;

  // Swap out `tracked_device_buffer_` so that no one can acquire a usage event
  // after this point.
  return DonationTransaction(this, std::move(tracked_device_buffer_));
}

static ShapedBuffer AsShapedBuffer(
    int device_ordinal, const Shape& on_device_shape,
    absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> buffers) {
  ShapedBuffer shaped_buffer(on_device_shape, device_ordinal);
  ShapeTree<se::DeviceMemoryBase>::iterator iterator =
      shaped_buffer.buffers().begin();
  for (const auto& buf : buffers) {
    CHECK(iterator != shaped_buffer.buffers().end());
    iterator->second = se::DeviceMemoryBase(buf->data(), buf->size());
    ++iterator;
  }
  CHECK(iterator == shaped_buffer.buffers().end());
  return shaped_buffer;
}

StatusOr<Shape> TfrtCpuBuffer::logical_on_device_shape() {
  if (on_device_shape_.is_static()) {
    return on_device_shape_;
  }

  auto usage_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
  auto* device_buffer = AcquireUsage(usage_event);
  if (device_buffer == nullptr) {
    return InvalidArgument(
        "logical_on_device_shape() called on deleted or donated buffer");
  }
  MarkEventReadyOnExit ready_on_exit(std::move(usage_event));

  // Wait for the definition event.
  const auto& av = device_buffer->definition_event();
  client_->GetHostContext()->Await(av.CopyRCRef());
  if (auto* error = av.GetErrorIfPresent()) {
    return InternalError("Error Execute: %s", error->message);
  }

  ShapedBuffer shaped_buffer = AsShapedBuffer(
      device_->local_hardware_id(), on_device_shape_, device_buffer->Buffers());
  Shape ret_shape = on_device_shape_;
  TF_RETURN_IF_ERROR(ReadDynamicShapesOnCpu(
      &shaped_buffer, &ret_shape, cpu::CpuExecutable::ShapeSizeBytes));
  return ret_shape;
}

static std::vector<tfrt::RCReference<tfrt::AsyncValue>> GetAsyncValues(
    absl::Span<const tfrt::AsyncValueRef<CpuEvent>> events) {
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> avs;
  avs.reserve(events.size());
  for (const auto& ev : events) {
    avs.push_back(ev.CopyRCRef());
  }
  return avs;
}

static std::vector<tfrt::RCReference<tfrt::AsyncValue>> CopyAsyncValues(
    absl::Span<const tfrt::RCReference<tfrt::AsyncValue>> events) {
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> avs;
  avs.reserve(events.size());
  for (const auto& ev : events) {
    avs.push_back(ev.CopyRef());
  }
  return avs;
}

// Enqueue to the TFRT non-blocking work queue when all `values` are ready.
static void EnqueueWorkWhenReady(
    tfrt::HostContext* host_ctx,
    tfrt::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> values,
    llvm::unique_function<void()> callee) {
  tfrt::RunWhenReady(values, [host_ctx, callee = std::move(callee)]() mutable {
    tfrt::EnqueueWork(host_ctx, std::move(callee));
  });
}

PjRtFuture<Status> TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal) {
  tensorflow::profiler::TraceMe traceme("TfrtCpuBuffer::ToLiteral");
  if (IsEmptyTuple()) {
    return PjRtFuture<Status>(
        InvalidArgument("ToLiteral called on empty tuple"));
  }
  auto usage_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
  auto* device_buffer = AcquireUsage(usage_event);
  if (device_buffer == nullptr) {
    return PjRtFuture<Status>(InvalidArgument(
        "CopyToHostAsync() called on deleted or donated buffer"));
  }
  MarkEventReadyOnExit ready_on_exit(std::move(usage_event));

  auto host_ctx = client_->GetHostContext();

  std::vector<tfrt::RCReference<tfrt::AsyncValue>> device_buffer_wait_avs = {
      device_buffer->definition_event().CopyRCRef()};
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> device_buffer_wait_avs_copy =
      CopyAsyncValues(device_buffer_wait_avs);

  bool should_sync_copy = device_buffer_wait_avs.empty() &&
                          literal->size_bytes() < kSmallDataTransferByteSize;
  if (should_sync_copy) {
    if (!on_device_shape().IsTuple()) {
      const std::shared_ptr<MaybeOwningCpuMemory>& b =
          device_buffer->Buffers()[0];
      std::memcpy(literal->untyped_data(), b->data(), b->size());
    } else {
      // Tuple case.
      int num_leaves = literal->shape().tuple_shapes().size();
      for (int i = 0; i < num_leaves; ++i) {
        const std::shared_ptr<MaybeOwningCpuMemory>& b =
            device_buffer->Buffers()[i];
        std::memcpy(literal->untyped_data({i}), b->data(), b->size());
      }
    }
    // Unblock the ToLiteral caller.
    return PjRtFuture<Status>(OkStatus());
  } else {
    auto ready_event = tfrt::MakeUnconstructedAsyncValueRef<Status>();
    // Wait for the buffer definition events to finish before the D2H dispatch.
    // D2H dispatches should run in parallel, e.g. one Execute event finishing
    // may trigger D2H for multiple outputs, which should then run on different
    // threads in parallel.
    EnqueueWorkWhenReady(
        host_ctx, device_buffer_wait_avs,
        [this, device_buffer_wait_avs = std::move(device_buffer_wait_avs_copy),
         literal, ready_event = ready_event.CopyRef(), device_buffer,
         ready_on_exit = std::move(ready_on_exit)]() mutable {
          tensorflow::profiler::TraceMe traceme("D2H Dispatch");
          // Errors in the src buffer are surfaced to the user.
          for (const auto& av : device_buffer_wait_avs) {
            if (auto* error = av->GetErrorIfPresent()) {
              ready_event.emplace(
                  Internal("Error converting to literal: %s", error->message));
              return;
            }
          }

          if (!on_device_shape().IsTuple()) {
            const std::shared_ptr<MaybeOwningCpuMemory>& b =
                device_buffer->Buffers()[0];
            std::memcpy(literal->untyped_data(), b->data(), b->size());
          } else {
            // Tuple case.
            int num_leaves = literal->shape().tuple_shapes().size();
            for (int i = 0; i < num_leaves; ++i) {
              const std::shared_ptr<MaybeOwningCpuMemory>& b =
                  device_buffer->Buffers()[i];
              std::memcpy(literal->untyped_data({i}), b->data(), b->size());
            }
          }

          // Unblock the ToLiteral event.
          ready_event.emplace(OkStatus());
        });
    return PjRtFuture<Status>(
        std::move(ready_event),
        /*on_block_start=*/
        []() {
          tensorflow::profiler::TraceMeProducer traceme(
              "TfrtCpuBuffer::ToLiteral");
          VLOG(1) << "TfrtCpuBuffer::ToLiteral";
          return PjRtFutureHelpers::ProfilingKeys(
              {/*traceme_context_id=*/traceme.GetContextId()});
        },
        /*on_block_end=*/
        [](PjRtFutureHelpers::ProfilingKeys keys) {
          tensorflow::profiler::TraceMeConsumer traceme(
              "TfrtCpuBuffer::ToLiteral", keys.traceme_context_id);
        });
  }
}
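
// Illustrative caller-side sketch (assumes PjRtFuture<Status>::Await(), which
// blocks until the D2H copy above completes; names are hypothetical):
//
//   Literal literal(buffer->on_device_shape());
//   PjRtFuture<Status> done = buffer->ToLiteral(&literal);
//   TF_RETURN_IF_ERROR(done.Await());  // Or done.OnReady(...) to stay async.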
1012
1013 // TODO(zhangqiaorjc): Consider disallowing multiple CPU devices and assign
1014 // multiple pmap replicas to the same CPU device for multi-CPU pmap testing.
CopyToDevice(PjRtDevice * dst_device)1015 StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuBuffer::CopyToDevice(
1016 PjRtDevice* dst_device) {
1017 tensorflow::profiler::TraceMe traceme("TfrtCpuBuffer::CopyToDevice");
1018 // TODO(zhangqiaorjc): Remove this restriction after removing the test that
1019 // explicitly asserts this.
1020 if (dst_device == device_) {
1021 return InvalidArgument(
1022 "CopyToDevice cannot accept the same source and destination devices");
1023 }
1024
1025 // Copying across PjRtClients involves a copy through the host.
1026 if (dst_device->client() != client_) {
1027 TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteralSync());
1028 // Avoid use-after-free on `literal` due to unsequenced move and use.
1029 Literal* literal_pointer = literal.get();
1030 absl::InlinedVector<int64_t, 4> byte_strides(
1031 literal->shape().dimensions_size());
1032 TF_RETURN_IF_ERROR(
1033 ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides)));
1034 return dst_device->client()->BufferFromHostBuffer(
1035 literal_pointer->untyped_data(),
1036 literal_pointer->shape().element_type(),
1037 literal_pointer->shape().dimensions(), byte_strides,
1038 TfrtCpuClient::HostBufferSemantics::kZeroCopy,
1039 [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
1040 }
1041
1042 // Copy each leaf buffer to a destination buffer.
1043 auto usage_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
1044 auto* src_device_buffer = AcquireUsage(usage_event);
1045 if (src_device_buffer == nullptr) {
1046 return InvalidArgument("CopyToDevice called on deleted or donated buffer");
1047 }
1048 MarkEventReadyOnExit ready_on_exit(std::move(usage_event));
1049
1050 int num_leaf_buffers = src_device_buffer->Buffers().size();
1051 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> src_buffers;
1052 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> dst_buffers;
1053 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> dst_definition_events;
1054 src_buffers.reserve(num_leaf_buffers);
1055 dst_buffers.reserve(num_leaf_buffers);
1056 dst_definition_events.reserve(num_leaf_buffers);
1057
1058 for (int i = 0; i < num_leaf_buffers; ++i) {
1059 auto src_buffer = src_device_buffer->Buffers()[i];
1060 TF_ASSIGN_OR_RETURN(auto dst_buffer, MaybeOwningCpuMemory::AllocateShared(
1061 src_buffer->size()));
1062 src_buffers.push_back(std::move(src_buffer));
1063 dst_buffers.push_back(std::move(dst_buffer));
1064 dst_definition_events.push_back(
1065 tfrt::MakeConstructedAsyncValueRef<CpuEvent>());
1066 }
1067
1068 // Wait for src buffer definition events to finish before d2d dispatch.
1069 // Errors are propagated asynchronously in dst buffer's definition events.
1070 const auto& src_definition_event = src_device_buffer->definition_event();
1071
1072 auto copy_task = [num_leaf_buffers, src_buffers = std::move(src_buffers),
1073 dst_buffers_copies = dst_buffers, dst_definition_events,
1074 src_definition_event,
1075 ready_on_exit = std::move(ready_on_exit)]() mutable {
1076 tensorflow::profiler::TraceMe traceme("D2D Dispatch");
1077 if (auto* error = src_definition_event.GetErrorIfPresent()) {
1078 for (int i = 0; i < num_leaf_buffers; ++i) {
1079 // Any error discovered in src buffer are propagated to dst buffer
1080 // definition events, which will surface to users in
1081 // dst_buffer->ToLiteral().
1082 dst_definition_events[i].SetError(*error);
1083 }
1084 return;
1085 }
1086
1087 for (int i = 0; i < num_leaf_buffers; ++i) {
1088 std::memcpy(dst_buffers_copies[i]->data(), src_buffers[i]->data(),
1089 src_buffers[i]->size());
1090 dst_definition_events[i].SetStateConcrete();
1091 }
1092 };
1093
1094 src_definition_event.AndThen([host_ctx = client()->GetHostContext(),
1095 copy_task = std::move(copy_task)]() mutable {
1096 tfrt::EnqueueWork(host_ctx, std::move(copy_task));
1097 });
1098
1099 return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>(
1100 on_device_shape_,
1101 std::make_unique<TrackedTfrtCpuDeviceBuffer>(
1102 on_device_shape_.IsTuple(), std::move(dst_buffers),
1103 std::move(dst_definition_events)),
1104 client(), tensorflow::down_cast<TfrtCpuDevice*>(dst_device)));
1105 }
1106
GetReadyFuture()1107 PjRtFuture<Status> TfrtCpuBuffer::GetReadyFuture() {
1108 tfrt::AsyncValueRef<CpuEvent> definition_event;
1109 {
1110 absl::MutexLock lock(&mu_);
1111 if (!tracked_device_buffer_) {
1112 return PjRtFuture<Status>(InvalidArgument(
1113 "GetReadyFuture() called on deleted or donated buffer"));
1114 }
1115 definition_event = tracked_device_buffer_->definition_event();
1116 }
1117 DCHECK(definition_event);
1118
1119 if (definition_event.IsAvailable()) {
1120 if (definition_event.IsError()) {
1121 return PjRtFuture<Status>(FailedPrecondition(
1122 "Buffer Definition Event: %s", definition_event.GetError().message));
1123 }
1124 return PjRtFuture<Status>(OkStatus());
1125 } else {
1126 tfrt::AsyncValueRef<Status> status_event =
1127 tfrt::MakeUnconstructedAsyncValueRef<Status>();
1128
1129 definition_event.AndThen(
1130 [definition_event = definition_event.AsPtr(), status_event]() {
1131 if (definition_event.IsError()) {
1132 status_event.emplace(
1133 FailedPrecondition("Buffer Definition Event: %s",
1134 definition_event.GetError().message));
1135 } else {
1136 status_event.emplace(OkStatus());
1137 }
1138 });
1139
1140 return PjRtFuture<Status>(
1141 std::move(status_event),
1142 /*on_block_start=*/
1143 []() {
1144 tensorflow::profiler::TraceMeProducer traceme("TfrtCpuBuffer::Await");
1145 VLOG(1) << "TfrtCpuBuffer::Await";
1146 return PjRtFutureHelpers::ProfilingKeys(
1147 {/*traceme_context_id=*/traceme.GetContextId()});
1148 },
1149 /*on_block_end=*/
1150 [](PjRtFutureHelpers::ProfilingKeys keys) {
1151 tensorflow::profiler::TraceMeConsumer traceme(
1152 "TfrtCpuBuffer::Await", keys.traceme_context_id);
1153 });
1154 }
1155 }
1156
TfrtCpuExecutable(int num_replicas,int num_partitions,std::shared_ptr<DeviceAssignment> device_assignment,bool parameter_is_tupled_arguments,std::unique_ptr<Executable> cpu_executable,BufferAllocation::Index result_buffer_index,absl::InlinedVector<BufferAllocation::Index,4> result_buffer_indices,std::vector<LogicalDeviceIds> addressable_device_logical_ids,std::vector<PjRtDevice * > addressable_devices,TfrtCpuClient * client)1157 TfrtCpuExecutable::TfrtCpuExecutable(
1158 int num_replicas, int num_partitions,
1159 std::shared_ptr<DeviceAssignment> device_assignment,
1160 bool parameter_is_tupled_arguments,
1161 std::unique_ptr<Executable> cpu_executable,
1162 BufferAllocation::Index result_buffer_index,
1163 absl::InlinedVector<BufferAllocation::Index, 4> result_buffer_indices,
1164 std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1165 std::vector<PjRtDevice*> addressable_devices, TfrtCpuClient* client)
1166 : client_(client),
1167 num_replicas_(num_replicas),
1168 num_partitions_(num_partitions),
1169 device_assignment_(std::move(device_assignment)),
1170 parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1171 cpu_executable_(std::move(cpu_executable)),
1172 result_buffer_index_(result_buffer_index),
1173 result_buffer_indices_(std::move(result_buffer_indices)),
1174 addressable_device_logical_ids_(
1175 std::move(addressable_device_logical_ids)),
1176 addressable_devices_(std::move(addressable_devices)) {
1177 auto hlo_cost_analysis =
1178 std::make_unique<HloCostAnalysis>(cpu::CpuExecutable::ShapeSizeBytes);
1179 // Cache to avoid std::map lookup in flop_count() on critical path.
1180 // The magic constant 1000 is determined by correlating computation with flop
1181 // estimate. It is a crude heuristic to find computation less than the thread
1182 // context switch time (~5us).
1183 cheap_computation_ = hlo_cost_analysis->flop_count() < 1000;
1184
1185 const auto& computation_layout =
1186 cpu_executable_->module().entry_computation_layout();
1187 if (computation_layout.parameter_count() == 0) {
1188 return;
1189 }
1190 // Assume compiled program expects either many non-tupled arguments or a
1191 // singled tupled argument. Nested tuple is not yet supported.
1192 if (computation_layout.parameter_count() > 1 ||
1193 !computation_layout.parameter_shape(0).IsTuple()) {
1194 input_buffer_sizes_in_bytes_.reserve(computation_layout.parameter_count());
1195 for (int i = 0; i < computation_layout.parameter_count(); ++i) {
1196 input_buffer_sizes_in_bytes_.push_back(
1197 ShapeUtil::ByteSizeOf(computation_layout.parameter_shape(i)));
1198 }
1199 } else {
1200 input_buffer_sizes_in_bytes_.reserve(
1201 computation_layout.parameter_shape(0).tuple_shapes_size());
1202 for (int i = 0;
1203 i < computation_layout.parameter_shape(0).tuple_shapes_size(); ++i) {
1204 input_buffer_sizes_in_bytes_.push_back(ShapeUtil::ByteSizeOf(
1205 computation_layout.parameter_shape(0).tuple_shapes(i)));
1206 }
1207 }
1208 }
1209
Delete()1210 void TfrtCpuExecutable::Delete() {}
1211
IsDeleted()1212 bool TfrtCpuExecutable::IsDeleted() { return false; }
1213
Fingerprint() const1214 StatusOr<std::optional<std::string>> TfrtCpuExecutable::Fingerprint() const {
1215 return std::optional<std::string>();
1216 }
1217
SetUpDonation(bool tuple_inputs)1218 Status TfrtCpuExecutable::SetUpDonation(bool tuple_inputs) {
1219 TF_ASSIGN_OR_RETURN(parameters_that_must_be_donated_,
1220 ComputeParametersThatMustBeDonated(
1221 *cpu_executable_->shared_module(), tuple_inputs));
1222 return OkStatus();
1223 }
1224
1225 // The following few helpers are adapted from XLA:CPU to create a buffer table
1226 // and assemble the buffer pointers in order to call into CpuExecutable.
MemoryForAllocation(const BufferAllocation & allocation,absl::Span<TrackedTfrtCpuDeviceBuffer * const> arguments)1227 static StatusOr<std::shared_ptr<MaybeOwningCpuMemory>> MemoryForAllocation(
1228 const BufferAllocation& allocation,
1229 absl::Span<TrackedTfrtCpuDeviceBuffer* const> arguments) {
1230 if (allocation.is_entry_computation_parameter()) {
1231 TrackedTfrtCpuDeviceBuffer* arg = arguments[allocation.parameter_number()];
1232 std::shared_ptr<MaybeOwningCpuMemory> out =
1233 arg->Buffer(allocation.param_shape_index());
1234 CHECK_EQ(allocation.size(), out->size())
1235 << "Size mismatch on param " << allocation.parameter_number()
1236 << " at shape index " << allocation.param_shape_index().ToString();
1237 return out;
1238 } else if (allocation.is_constant()) {
1239 return std::make_shared<MaybeOwningCpuMemory>();
1240 } else if (allocation.is_thread_local()) {
1241 return std::make_shared<MaybeOwningCpuMemory>();
1242 }
1243
1244 // Output and temporary buffer.
1245 int64_t buffer_size = allocation.size();
1246 TF_ASSIGN_OR_RETURN(auto out,
1247 MaybeOwningCpuMemory::AllocateShared(buffer_size));
1248
1249 // Since the output buffer and all the temporary buffers were written into
1250 // by the JITed code, msan has no way of knowing their memory was
1251 // initialized. Mark them initialized so that msan doesn't flag loads from
1252 // these buffers.
1253 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(out->data(), buffer_size);
1254 return out;
1255 }
1256
1257 static StatusOr<std::vector<std::shared_ptr<MaybeOwningCpuMemory>>>
CreateBufferTable(const BufferAssignment & assignment,absl::Span<TrackedTfrtCpuDeviceBuffer * const> arguments)1258 CreateBufferTable(const BufferAssignment& assignment,
1259 absl::Span<TrackedTfrtCpuDeviceBuffer* const> arguments) {
1260 std::vector<std::shared_ptr<MaybeOwningCpuMemory>> buffers(
1261 assignment.Allocations().size());
1262 for (BufferAllocation::Index i = 0; i < assignment.Allocations().size();
1263 ++i) {
1264 const BufferAllocation& allocation = assignment.GetAllocation(i);
1265 TF_ASSIGN_OR_RETURN(buffers[i], MemoryForAllocation(allocation, arguments));
1266 }
1267 return std::move(buffers);
1268 }
1269
1270 static absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4>
CreateResultShapedBuffer(absl::Span<const BufferAllocation::Index> buffer_indices,absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> buffer_table,absl::Span<TrackedTfrtCpuDeviceBuffer * const> arguments)1271 CreateResultShapedBuffer(
1272 absl::Span<const BufferAllocation::Index> buffer_indices,
1273 absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> buffer_table,
1274 absl::Span<TrackedTfrtCpuDeviceBuffer* const> arguments) {
1275 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> output_buffers;
1276 output_buffers.reserve(buffer_indices.size());
1277 for (int i = 0; i < buffer_indices.size(); ++i) {
1278 output_buffers.push_back(buffer_table[buffer_indices[i]]);
1279 }
1280 return output_buffers;
1281 }
1282
1283 Status TfrtCpuExecutable::CheckBufferCompatibilities(
1284 absl::Span<TrackedTfrtCpuDeviceBuffer* const> input_buffers) const {
1285 if (input_buffers.size() != input_buffer_sizes_in_bytes_.size()) {
1286 return InvalidArgument(
1287 "Execution supplied %lld buffers but compiled program expected %lld "
1288 "buffers",
1289 input_buffers.size(), input_buffer_sizes_in_bytes_.size());
1290 }
1291 for (int i = 0; i < input_buffers.size(); ++i) {
1292 const auto& buffer = input_buffers[i];
1293 if (input_buffer_sizes_in_bytes_[i] != buffer->Buffers()[0]->size()) {
1294 return InvalidArgument(
1295 "Executable expected parameter %d of size %lld but got buffer with "
1296 "incompatible size %lld",
1297 i, input_buffer_sizes_in_bytes_[i], buffer->Buffers()[0]->size());
1298 }
1299 }
1300 return OkStatus();
1301 }
1302
1303 StatusOr<PjRtLoadedExecutable::Result> TfrtCpuExecutable::ExecuteHelper(
1304 absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1305 const RunId& run_id, const ExecuteOptions& options,
1306 tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event,
1307 bool fill_future, TfrtCpuDevice* device) {
1308 tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecuteHelper");
1309 auto* host_context = client_->GetHostContext();
1310
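// Resolve the target device: either from the compiled DeviceAssignment via
// (replica, partition), or, for portable executables, from the device passed
// in explicitly by the caller.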
1311 std::shared_ptr<DeviceAssignment> device_assignment;
1312 if (device == nullptr) {
1313 CHECK(device_assignment_ != nullptr);
1314 const int device_id = (*device_assignment_)(replica, partition);
1315 TF_ASSIGN_OR_RETURN(PjRtDevice * pjrt_device,
1316 client_->LookupDevice(device_id));
1317 device = tensorflow::down_cast<TfrtCpuDevice*>(pjrt_device);
1318 device_assignment = device_assignment_;
1319 } else {
1320 CHECK(device_assignment_ == nullptr);
1321 CHECK_EQ(replica, 0);
1322 CHECK_EQ(partition, 0);
1323 CHECK(addressable_devices_.empty());
1324 device_assignment = std::make_shared<DeviceAssignment>(1, 1);
1325 (*device_assignment)(0, 0) = device->id();
1326 }
1327 CHECK_EQ(device->process_index(), client_->process_index());
1328
1329 // Handle inputs.
1330 if (options.arguments_are_tupled) {
1331 if (!parameter_is_tupled_arguments_) {
1332 return InvalidArgument(
1333 "Arguments may only be supplied as a tuple when the executable was "
1334 "compiled with a single tupled parameter");
1335 }
1336 if (argument_handles.size() != 1) {
1337 return InvalidArgument(
1338 "Option arguments_are_tupled was true but %d buffers were passed to "
1339 "execution",
1340 argument_handles.size());
1341 }
1342 }
1343
1344 // `execute_event` indicates whether the CPU computation is complete and
1345 // whether there was an error.
1346 auto execute_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
1347 MarkEventReadyOnExit ready_on_exit(execute_event);
1348
1349 absl::InlinedVector<TfrtCpuBuffer::DonationTransaction, 4>
1350 donation_transactions;
1351 absl::InlinedVector<TrackedTfrtCpuDeviceBuffer*, 4> tracked_buffers;
1352 tracked_buffers.reserve(argument_handles.size());
1353 // To avoid clobbering inputs, we must ensure that
1354 // `extra_deps` = inputs' definition events + donated inputs' usage events.
1355 // This also ensures that the returned `execute_event` dominates all inputs'
1356 // events, and thus the output buffers only need to contain `execute_event` as
1357 // their single definition event.
1358 std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps;
1359 input_deps.reserve(argument_handles.size());
1360
1361 auto donate_it = parameters_that_must_be_donated_.begin();
1362
1363 for (int i = 0; i < argument_handles.size(); ++i) {
1364 PjRtBuffer* handle = argument_handles[i];
1365 auto* tfrt_buffer = tensorflow::down_cast<TfrtCpuBuffer*>(handle);
1366 if (tfrt_buffer->device() != device) {
1367 return InvalidArgument(
1368 "Buffer passed to Execute() as argument %d to replica %d is on "
1369 "device %s, but replica is assigned to device %s.",
1370 i, replica, tfrt_buffer->device()->DebugString(),
1371 device->DebugString());
1372 }
1373
1374 bool must_donate =
1375 donate_it != parameters_that_must_be_donated_.end() && *donate_it == i;
1376 TrackedTfrtCpuDeviceBuffer* tracked_buffer = nullptr;
1377 if (must_donate) {
1378 ++donate_it;
1379 TF_ASSIGN_OR_RETURN(auto donation_transaction,
1380 tfrt_buffer->AcquireDonation());
1381
1382 // After acquiring the buffer for donation, we retrieve the dependent
1383 // usage events. Note that we don't need any locking here as
1384 // AcquireDonation() is supposed to synchronize with other usages.
1385 for (const auto& ev :
1386 donation_transaction.device_buffer()->UsageEvents()) {
1387 if (!ev.IsAvailable()) {
1388 input_deps.push_back(ev.CopyRCRef());
1389 }
1390 }
1391 tracked_buffer = donation_transaction.device_buffer();
1392 tracked_buffers.push_back(tracked_buffer);
1393 donation_transactions.push_back(std::move(donation_transaction));
1394
1395 } else {
1396 tracked_buffer = tfrt_buffer->AcquireUsage(execute_event);
1397 if (!tracked_buffer)
1398 return InvalidArgument(
1399 "Invalid buffer passed: buffer has been deleted or donated.");
1400 tracked_buffers.push_back(tracked_buffer);
1401 }
1402
1403 // Definition events are never modified after buffer construction.
1404 const auto& definition_event = tracked_buffer->definition_event();
1405 if (!definition_event.IsAvailable()) {
1406 input_deps.push_back(definition_event.CopyRCRef());
1407 }
1408 }
1409
1410 TF_RETURN_IF_ERROR(CheckBufferCompatibilities(tracked_buffers));
1411
1412 // Tuplize the inputs if the compiler expects a single tuple argument but the
1413 // runtime gets many inputs that are not yet tupled.
1414 std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tuplized_arg;
1415 if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1416 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> leaf_buffers;
1417 leaf_buffers.reserve(tracked_buffers.size());
1418 for (const auto& tracked_buffer : tracked_buffers) {
1419 auto span = tracked_buffer->Buffers();
1420 leaf_buffers.insert(leaf_buffers.end(), span.begin(), span.end());
1421 }
1422
1423 // Tuplize into a single input.
1424 tracked_buffers.clear();
1425 tuplized_arg = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
1426 /*is_tuple=*/true, std::move(leaf_buffers),
1427 /*definition_event=*/tfrt::MakeAvailableAsyncValueRef<CpuEvent>());
1428 tracked_buffers.push_back(tuplized_arg.get());
1429 }
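// `tuplized_arg` must outlive the computation: it stays on this stack frame in
// the synchronous path and is moved into the asynchronous closure otherwise.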
1430
1431 auto* cpu_executable =
1432 tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable_.get());
1433 TF_ASSIGN_OR_RETURN(
1434 std::vector<std::shared_ptr<MaybeOwningCpuMemory>> buffer_table,
1435 CreateBufferTable(cpu_executable->buffer_assignment(), tracked_buffers));
1436 auto result_buffers = CreateResultShapedBuffer(result_buffer_indices_,
1437 buffer_table, tracked_buffers);
1438
1439 // The choice of where we wait is arbitrary; the reason for the wait is
1440 // pacing to avoid problems such as memory fragmentation and running ahead
1441 // too far, not for correctness. Placing it before the executable launch
1442 // allows the inputs for the next executable to be fetched even if the
1443 // launch is delayed.
1444 auto compute_reservation = std::make_unique<Semaphore::ScopedReservation>(
1445 device->max_inflight_computations_semaphore().ScopedAcquire(1));
1446
1447 // Call the computation function following the calling convention.
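// The compiled entry point takes, in order: the result buffer, the
// ExecutableRunOptions, an args array (unused here), the flat buffer table,
// an XlaCustomCallStatus out-parameter, and profile counters (unused here).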
1448 std::vector<void*> buffer_pointers;
1449 buffer_pointers.reserve(buffer_table.size());
1450 for (const auto& buffer : buffer_table) {
1451 buffer_pointers.push_back(buffer->data());
1452 }
1453 void* result_buffer = buffer_pointers[result_buffer_index_];
1454
1455 ExecutableRunOptions run_options;
1456 run_options.set_run_id(run_id);
1457 run_options.set_device_ordinal(device->local_hardware_id());
1458 // Need to keep device_assignment alive until execution completes.
1459 run_options.set_device_assignment(device_assignment.get());
1460 run_options.set_intra_op_thread_pool(client_->eigen_intraop_device());
1461
1462 // Schedule only one collective at a time.
1463 bool is_a_collective_launch = !!last_collective_launch_event;
1464 if (is_a_collective_launch) {
1465 input_deps.push_back(std::move(last_collective_launch_event));
1466 }
1467
1468 bool execute_inline = cheap_computation_;
1469
1470 // Overwrite `execute_inline` if it is specified in the ExecuteOptions.
1471 if (options.execution_mode == ExecuteOptions::ExecutionMode::kAsynchronous) {
1472 execute_inline = false;
1473 } else if (options.execution_mode ==
1474 ExecuteOptions::ExecutionMode::kSynchronous) {
1475 execute_inline = true;
1476 }
1477
1478 if (input_deps.empty() && execute_inline) {
1479 // Synchronously call generated function.
1480
1481 // Set denormal and rounding behavior to match the default TF
1482 // ThreadPool behavior.
1483 tensorflow::port::ScopedFlushDenormal flush;
1484 tensorflow::port::ScopedSetRound round(FE_TONEAREST);
1485
1486 XlaCustomCallStatus status;
1487
1488 // Call generated function.
1489 cpu_executable->compute_function()(result_buffer, &run_options, nullptr,
1490 buffer_pointers.data(), &status,
1491 nullptr);
1492
1493 for (auto& donation_transaction : donation_transactions) {
1494 std::move(donation_transaction).Commit();
1495 }
1496
1497 std::optional<absl::string_view> error_message =
1498 xla::CustomCallStatusGetMessage(&status);
1499 if (error_message) {
1500 return InternalError("Generated function failed: %s", *error_message);
1501 }
1502
1503 } else {
1504 // TODO(zhangqiaorjc): Only async launch expensive computations. Need
1505 // heuristics to decide what computation is expensive.
1506 // Asynchronously call generated function.
1507
1508 // We only created enough threads for one collective to complete.
1509 // The next collective launch will not be scheduled onto the threadpool until
1510 // this one completes.
1511 if (is_a_collective_launch) {
1512 client_->SetLastCollectiveLaunchEvent(execute_event.CopyRef());
1513 }
1514 std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps_avs_copy =
1515 CopyAsyncValues(input_deps);
1516 EnqueueWorkWhenReady(
1517 host_context, input_deps,
1518 [cpu_executable, result_buffer,
1519 buffer_pointers = std::move(buffer_pointers),
1520 buffer_table = std::move(buffer_table),
1521 run_options = std::move(run_options),
1522 cpu_executable_copy = cpu_executable_,
1523 device_assignment = std::move(device_assignment),
1524 compute_reservation = std::move(compute_reservation),
1525 tuplized_arg = std::move(tuplized_arg),
1526 donation_transactions = std::move(donation_transactions),
1527 execute_event = std::move(ready_on_exit).Release(),
1528 input_deps_avs = std::move(input_deps_avs_copy)]() mutable {
1529 for (const auto& av : input_deps_avs) {
1530 if (auto* error = av->GetErrorIfPresent()) {
1531 execute_event.SetError(absl::StrCat(
1532 "Error dispatching computation: %s", error->message));
1533 return;
1534 }
1535 }
1536
1537 // Set denormal and rounding behavior to match the default TF
1538 // ThreadPool behavior.
1539 tensorflow::port::ScopedFlushDenormal flush;
1540 tensorflow::port::ScopedSetRound round(FE_TONEAREST);
1541
1542 XlaCustomCallStatus status;
1543
1544 // Call generated function.
1545 cpu_executable->compute_function()(result_buffer, &run_options,
1546 nullptr, buffer_pointers.data(),
1547 &status, nullptr);
1548
1549 std::optional<absl::string_view> error_message =
1550 xla::CustomCallStatusGetMessage(&status);
1551
1552 for (auto& donation_transaction : donation_transactions) {
1553 std::move(donation_transaction).Commit();
1554 }
1555
1556 if (error_message) {
1557 // CPU computation fails with an error.
1558 execute_event.SetError(absl::StrFormat(
1559 "Generated function failed: %s", *error_message));
1560 return;
1561 }
1562
1563 // CPU computation completes.
1564 execute_event.SetStateConcrete();
1565 });
1566 }
1567
1568 // Create output TFRT buffers.
1569 const Shape& result_shape = cpu_executable_->result_shape();
1570 std::vector<std::unique_ptr<PjRtBuffer>> res;
1571 if (options.untuple_result && result_shape.IsTuple()) {
1572 res.reserve(result_buffers.size());
1573 for (int i = 0; i < result_buffers.size(); ++i) {
1574 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> sub_buffer;
1575 sub_buffer.push_back(std::move(result_buffers[i]));
1576 // Program execution writes to output buffers so it's a definition event.
1577 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events;
1578 definition_events.push_back(execute_event.CopyRef());
1579 auto leaf_tracked_device_buffer =
1580 std::make_unique<TrackedTfrtCpuDeviceBuffer>(
1581 /*is_tuple=*/false, std::move(sub_buffer),
1582 std::move(definition_events));
1583 auto leaf_buffer = std::make_unique<TfrtCpuBuffer>(
1584 result_shape.tuple_shapes(i), std::move(leaf_tracked_device_buffer),
1585 client_, device);
1586 res.push_back(std::move(leaf_buffer));
1587 }
1588 } else {
1589 // Program execution writes to output buffers so it's a definition event.
1590 auto tracked_device_buffer = std::make_unique<TrackedTfrtCpuDeviceBuffer>(
1591 /*is_tuple=*/result_shape.IsTuple(), std::move(result_buffers),
1592 /*definition_event=*/execute_event);
1593 auto tfrt_output_buffer = std::make_unique<TfrtCpuBuffer>(
1594 result_shape, std::move(tracked_device_buffer), client_, device);
1595 res.push_back(std::move(tfrt_output_buffer));
1596 }
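// Optionally expose completion of the computation as a PjRtFuture that
// resolves once `execute_event` becomes available, converting any error into
// a Status.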
1597 std::optional<PjRtFuture<Status>> future;
1598 if (fill_future) {
1599 auto done_event = tfrt::MakeUnconstructedAsyncValueRef<Status>();
1600 execute_event.AndThen(
1601 [done_event = done_event.CopyRef(), event = execute_event.CopyRef()]() {
1602 Status s;
1603 if (auto* error = event.GetErrorIfPresent()) {
1604 s = InternalError("Compute error: %s", error->message);
1605 }
1606 done_event.emplace(std::move(s));
1607 });
1608 future = PjRtFuture<Status>(std::move(done_event));
1609 }
1610 return Result({/*future=*/std::move(future), /*buffers=*/std::move(res)});
1611 }
1612
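// For illustration only (not part of this file): callers typically reach this
// code through the PjRtLoadedExecutable interface, roughly as follows, where
// `executable`, `args`, and `options` are assumed to be set up elsewhere:
//
//   std::optional<std::vector<PjRtFuture<Status>>> returned_futures(
//       std::in_place);
//   TF_ASSIGN_OR_RETURN(auto per_device_results,
//                       executable->Execute(args, options, returned_futures));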
1613 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
1614 TfrtCpuExecutable::Execute(
1615 absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1616 const ExecuteOptions& options,
1617 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
1618 tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::Execute");
1619 if (device_assignment_ == nullptr) {
1620 return InvalidArgument("Execute expects a non-null device_assignment");
1621 }
1622
1623 RunId run_id;
1624 tensorflow::profiler::TraceMeProducer activity(
1625 "TfrtCpuExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
1626 run_id.ToInt());
1627
1628 const int num_addressable_devices = addressable_devices_.size();
1629
1630 if (argument_handles.size() != num_addressable_devices) {
1631 return InvalidArgument(
1632 "Attempted to execute with %d argument lists when local device "
1633 "count is %d (total replica count: %d, partition count: %d)",
1634 argument_handles.size(), num_addressable_devices, num_replicas(),
1635 num_partitions());
1636 }
1637
1638 VLOG(1) << "Executing computation " << name()
1639 << "; num_replicas=" << num_replicas()
1640 << " num_partitions=" << num_partitions()
1641 << " num_addressable_devices=" << num_addressable_devices;
1642
1643 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
1644 num_addressable_devices);
1645 if (returned_futures.has_value()) {
1646 returned_futures->resize(num_addressable_devices);
1647 }
1648 if (num_addressable_devices == 1) {
1649 // Fast-path if there is only one device — run the computation on the
1650 // current thread.
1651 const int replica = addressable_device_logical_ids_[0].replica;
1652 const int partition = addressable_device_logical_ids_[0].partition;
1653
1654 auto statusor = ExecuteHelper(
1655 argument_handles[0], replica, partition, run_id, options,
1656 /*last_collective_launch_event=*/tfrt::AsyncValueRef<CpuEvent>(),
1657 returned_futures.has_value());
1658
1659 if (!statusor.ok()) {
1660 return std::move(statusor).status();
1661 }
1662
1663 wrapped_results[0] = std::move(statusor->buffers);
1664 if (returned_futures.has_value()) {
1665 (*returned_futures)[0] = std::move(*statusor->future);
1666 }
1667
1668 } else {
1669 // Gang schedule collectives to ensure that collectives with the same RunId
1670 // are run at the same time. We conservatively run only one collective at a
1671 // time, because we may not have enough threads to run an arbitrary number of
1672 // collectives concurrently.
1673 tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event =
1674 client_->GetLastCollectiveLaunchEvent();
1675
1676 absl::Mutex mu;
1677 int running = num_addressable_devices;
1678 int failed = 0;
1679 Status first_failure_status;
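// Each per-device execution decrements `running` under `mu`; the calling
// thread blocks below until all launches have finished and then reports the
// first failure, if any.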
1680
1681 for (int i = 0; i < num_addressable_devices; ++i) {
1682 const int replica = addressable_device_logical_ids_[i].replica;
1683 const int partition = addressable_device_logical_ids_[i].partition;
1684 tfrt::EnqueueWork(client_->GetHostContext(), [&, replica, partition, i] {
1685 auto statusor =
1686 ExecuteHelper(argument_handles[i], replica, partition, run_id,
1687 options, last_collective_launch_event.CopyRef(),
1688 returned_futures.has_value());
1689 if (statusor.ok()) {
1690 wrapped_results[i] = std::move(statusor->buffers);
1691 if (returned_futures.has_value()) {
1692 (*returned_futures)[i] = std::move(*statusor->future);
1693 }
1694 }
1695
1696 absl::MutexLock lock(&mu);
1697 --running;
1698 if (!statusor.ok()) {
1699 if (failed == 0) {
1700 first_failure_status = AppendStatus(
1701 std::move(statusor).status(),
1702 absl::StrFormat(
1703 "while running replica %d and partition %d of a "
1704 "replicated computation (other "
1705 "replicas may have failed as well).",
1706 replica, partition));
1707 }
1708 ++failed;
1709 }
1710 });
1711 }
1712
1713 {
1714 auto done_running = [&]() {
1715 mu.AssertHeld();
1716 return running == 0;
1717 };
1718 absl::MutexLock lock(&mu);
1719 mu.Await(absl::Condition(&done_running));
1720 }
1721
1722 if (!first_failure_status.ok()) return first_failure_status;
1723 }
1724 VLOG(1) << "Replicated execution complete.";
1725
1726 return wrapped_results;
1727 }
1728
1729 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1730 TfrtCpuExecutable::ExecuteSharded(
1731 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1732 const ExecuteOptions& options,
1733 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
1734 tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecuteSharded");
1735 if (device_assignment_ == nullptr) {
1736 return InvalidArgument("ExecuteSharded expects a non-null device_assignment");
1737 }
1738 for (int i = 0; i < addressable_devices_.size(); ++i) {
1739 if (addressable_devices_[i] == device) {
1740 VLOG(1) << "ExecuteSharded executes computation " << name()
1741 << " on assigned replica/partition on device "
1742 << device->DebugString();
1743 TF_ASSIGN_OR_RETURN(
1744 auto result,
1745 ExecuteHelper(
1746 argument_handles, addressable_device_logical_ids_[i].replica,
1747 addressable_device_logical_ids_[i].partition, RunId(), options,
1748 /*last_collective_launch_event=*/
1749 tfrt::AsyncValueRef<CpuEvent>(), fill_future));
1750 returned_future = std::move(result.future);
1751 return std::move(result.buffers);
1752 }
1753 }
1754 return InvalidArgument(
1755 "ExecuteShard attempted to execute on device id %d which is not "
1756 "addressable by this client",
1757 device->id());
1758 }
1759
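// Portable executables are compiled without a device assignment; the target
// device is supplied by the caller at execution time.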
1760 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
1761 TfrtCpuExecutable::ExecutePortable(
1762 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1763 const ExecuteOptions& options,
1764 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
1765 tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecutePortable");
1766 if (device_assignment_ != nullptr) {
1767 return InvalidArgument("ExecutePortable was called on a non-portable executable");
1768 }
1769 if (num_replicas() != 1 || num_partitions() != 1) {
1770 return InvalidArgument(
1771 "ExecutePortable expects a single-core executable but gets "
1772 "one with %d replica %d partition",
1773 num_replicas(), num_partitions());
1774 }
1775 if (device == nullptr) {
1776 return InvalidArgument("ExecutePortable expects a device to be specified");
1777 }
1778 VLOG(1) << "ExecutePortable executes single-core portable executable "
1779 << name();
1780 TF_ASSIGN_OR_RETURN(
1781 auto result,
1782 ExecuteHelper(
1783 argument_handles,
1784 /*replica=*/0,
1785 /*partition=*/0, RunId(), options,
1786 /*last_collective_launch_event=*/tfrt::AsyncValueRef<CpuEvent>(),
1787 fill_future, tensorflow::down_cast<TfrtCpuDevice*>(device)));
1788 returned_future = std::move(result.future);
1789 return std::move(result.buffers);
1790 }
1791 } // namespace xla
1792