1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implementation notes:
17 //
18 // Asynchronous execution:
19 // -----------------------
20 //
21 // Computations and host-to-device transfers do not need to block the host
22 // waiting for the operation to complete but instead return control to the host
23 // immediately. This allows client logic to overlap with device-side
24 // computation.
25 //
26 // For a good user experience, we must be careful only to enqueue operations
27 // that are unlikely to fail; as a rule error checking must be done eagerly
28 // before returning control to the client.
29 //
30 // The degree to which the client can enqueue operations ahead of the client
31 // is limited by a semaphore. There are at two modes: asynchronous, where we
32 // allow the client to enqueue up to 32 executions ahead of the device, and
33 // synchronous, where we limit the client to having one enqueued operation at
34 // a time. The value of 32 is arbitrary.
35 //
36 // Even in asynchronous mode, it is important that we do not permit
37 // unbounded queue-ahead. Firstly it is problematic when the user does something
38 // like the following in Python:
39 // %timeit run_computation()
40 // To the timeit logic, op() appears to be extremely cheap since it is deferring
41 // all of its real work and not blocking, and so the %timeit will run op() many
42 // (e.g., 10000) times to get better timing resolution, even though in reality
43 // it may be expensive. Secondly, on CPU the allocator is synchronized with the
44 // head of the compute stream, and we allocate buffers for all of the enqueued
45 // programs without any reuse (unlike GPU). This means that the memory usage
46 // is proportional to the queue size.
47 //
48 // Multi-stream execution:
49 // -----------------------
50 //
51 // We use a multistream execution design, where different Streams are used for
52 // host-to-device transfers, device-to-host transfers, and compute. This allows
53 // us to overlap transfers on and off the device with computation.
54 //
55 // Synchronization between streams occurs via BufferSequencingEvents that
56 // describe when the contents of a logical buffer are known to be valid on
57 // a particular stream, and when a buffer's uses have all completed.
58 //
59 // Synchronous vs asynchronous deallocation:
60 // -----------------------------------------
61 //
62 // See the comment on LocalDeviceState::AllocationModel for a discussion of the
63 // different allocation semantics on CPU, GPU, and TPU.
64
65 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
66
67 #include <algorithm>
68 #include <cstddef>
69 #include <cstdlib>
70 #include <functional>
71 #include <memory>
72 #include <optional>
73 #include <string>
74 #include <utility>
75 #include <vector>
76
77 #include "absl/algorithm/container.h"
78 #include "absl/base/casts.h"
79 #include "absl/container/flat_hash_set.h"
80 #include "absl/container/inlined_vector.h"
81 #include "absl/strings/str_format.h"
82 #include "absl/synchronization/mutex.h"
83 #include "absl/time/time.h"
84 #include "absl/types/span.h"
85 #include "tensorflow/compiler/xla/client/local_client.h"
86 #include "tensorflow/compiler/xla/client/xla_computation.h"
87 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
88 #include "tensorflow/compiler/xla/executable_run_options.h"
89 #include "tensorflow/compiler/xla/layout.h"
90 #include "tensorflow/compiler/xla/literal.h"
91 #include "tensorflow/compiler/xla/literal_util.h"
92 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
93 #include "tensorflow/compiler/xla/pjrt/event_pool.h"
94 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
95 #include "tensorflow/compiler/xla/pjrt/metrics.h"
96 #include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
97 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
98 #include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
99 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
100 #include "tensorflow/compiler/xla/pjrt/utils.h"
101 #include "tensorflow/compiler/xla/service/computation_layout.h"
102 #include "tensorflow/compiler/xla/service/executable.h"
103 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
104 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
105 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
106 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
107 #include "tensorflow/compiler/xla/service/transfer_manager.h"
108 #include "tensorflow/compiler/xla/shape.h"
109 #include "tensorflow/compiler/xla/shape_util.h"
110 #include "tensorflow/compiler/xla/util.h"
111 #include "tensorflow/compiler/xla/xla_data.pb.h"
112 #include "tensorflow/core/platform/cpu_info.h"
113 #include "tensorflow/core/platform/env.h"
114 #include "tensorflow/core/platform/errors.h"
115 #include "tensorflow/core/platform/fingerprint.h"
116 #include "tensorflow/core/platform/mem.h"
117 #include "tensorflow/core/platform/status.h"
118 #include "tensorflow/core/platform/statusor.h"
119 #include "tensorflow/core/profiler/lib/connected_traceme.h"
120 #include "tensorflow/core/profiler/lib/traceme.h"
121 #include "tensorflow/core/profiler/lib/traceme_encode.h"
122 #include "tensorflow/stream_executor/device_memory.h"
123 #include "tensorflow/stream_executor/device_memory_allocator.h"
124 #include "tensorflow/stream_executor/event.h"
125 #include "tensorflow/stream_executor/host/host_platform_id.h"
126 #include "tensorflow/stream_executor/lib/statusor.h"
127 #include "tensorflow/stream_executor/stream.h"
128
129 namespace xla {
130
platform_id() const131 PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
132 return client_->platform_id();
133 }
platform_name() const134 absl::string_view PjRtStreamExecutorDevice::platform_name() const {
135 return client_->platform_name();
136 }
137
GetLocalDeviceState() const138 StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
139 const {
140 if (local_device_state_) {
141 return local_device_state_.get();
142 }
143 return InvalidArgument("Device %s is not a local device.", DebugString());
144 }
145
DebugString() const146 absl::string_view PjRtStreamExecutorDevice::DebugString() const {
147 return debug_string_;
148 }
149
ToString() const150 absl::string_view PjRtStreamExecutorDevice::ToString() const {
151 return to_string_;
152 }
153
DevicesToDeviceAssignment(absl::Span<const std::vector<PjRtDevice * >> devices)154 StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
155 absl::Span<const std::vector<PjRtDevice*>> devices) {
156 if (devices.empty()) {
157 return InvalidArgument(
158 "Device assignment passed to Compile() must be non-empty.");
159 }
160 if (devices[0].empty()) {
161 return InvalidArgument(
162 "Device assignment passed to Compile() must have a nonzero number of "
163 "partitions per replica; replica 0 had 0 partitions.");
164 }
165 DeviceAssignment xla_assignment(devices.size(), devices[0].size());
166 for (int replica = 0; replica < devices.size(); ++replica) {
167 if (devices[replica].size() != devices[0].size()) {
168 return InvalidArgument(
169 "Device assignment passed to Compile() has different numbers of "
170 "partitions between replicas; %d partitions for replica %d versus %d "
171 "partitions for replica 0.",
172 devices[replica].size(), replica, devices[0].size());
173 }
174 for (int partition = 0; partition < devices[replica].size(); ++partition) {
175 if (devices[0][0]->client()->platform_id() !=
176 devices[replica][partition]->client()->platform_id()) {
177 return InvalidArgument(
178 "Device assignment passed to Compile() must have devices of a "
179 "single kind, got %s for replica 0 partition 0 and %s for replica "
180 "%d partition %d.",
181 devices[0][0]->client()->platform_name(),
182 devices[replica][partition]->client()->platform_name(), replica,
183 partition);
184 }
185 xla_assignment(replica, partition) = devices[replica][partition]->id();
186 }
187 }
188 return xla_assignment;
189 }
190
191 class CpuAllocator : public tensorflow::Allocator {
192 public:
193 CpuAllocator() = default;
194
Name()195 std::string Name() override { return "cpu"; }
196
AllocateRaw(size_t alignment,size_t num_bytes)197 void* AllocateRaw(size_t alignment, size_t num_bytes) override {
198 return tensorflow::port::AlignedMalloc(num_bytes, alignment);
199 }
DeallocateRaw(void * ptr)200 void DeallocateRaw(void* ptr) override {
201 return tensorflow::port::AlignedFree(ptr);
202 }
203 };
204
PjRtStreamExecutorClient(std::string platform_name,LocalClient * client,std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,int process_index,std::unique_ptr<se::DeviceMemoryAllocator> allocator,std::unique_ptr<tensorflow::Allocator> host_memory_allocator,bool should_stage_host_to_device_transfers,std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)205 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
206 std::string platform_name, LocalClient* client,
207 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
208 int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
209 std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
210 bool should_stage_host_to_device_transfers,
211 std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
212 : platform_id_(tensorflow::Fingerprint64(platform_name)),
213 platform_name_(std::move(platform_name)),
214 client_(client),
215 host_memory_allocator_(std::move(host_memory_allocator)),
216 owned_allocator_(std::move(allocator)),
217 owned_devices_(std::move(devices)),
218 process_index_(process_index),
219 should_stage_host_to_device_transfers_(
220 should_stage_host_to_device_transfers),
221 gpu_run_options_(std::move(gpu_run_options)),
222 thread_pool_(
223 tensorflow::Env::Default(), "pjrt_thread_pool",
224 std::max<int>(DefaultThreadPoolSize(), client->device_count())),
225 transpose_cache_(1024) {
226 if (owned_allocator_ != nullptr) {
227 allocator_ = owned_allocator_.get();
228 } else {
229 allocator_ = client_->backend().memory_allocator();
230 }
231
232 if (!host_memory_allocator_) {
233 host_memory_allocator_ = std::make_unique<CpuAllocator>();
234 }
235
236 for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
237 owned_devices_) {
238 devices_.push_back(device.get());
239 CHECK(id_to_device_.insert({device->id(), device.get()}).second)
240 << "Duplicate device id: " << device->id();
241
242 if (device->IsAddressable()) {
243 addressable_devices_.push_back(device.get());
244 }
245 device->SetClient(this);
246 }
247 // TODO(phawkins): we don't really promise anything about the order of
248 // these devices, but users may be depending on the current order. Sort into
249 // device ordinal order, which is the historical order these values have
250 // appeared.
251 absl::c_sort(addressable_devices_,
252 [](const PjRtDevice* a, const PjRtDevice* b) {
253 return a->local_hardware_id() < b->local_hardware_id();
254 });
255 }
256
GetDefaultDeviceAssignment(int num_replicas,int num_partitions) const257 StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
258 int num_replicas, int num_partitions) const {
259 return client_->backend().computation_placer()->AssignDevices(num_replicas,
260 num_partitions);
261 }
262
263 StatusOr<std::unique_ptr<HloCostAnalysis>>
GetHloCostAnalysis()264 PjRtStreamExecutorClient::GetHloCostAnalysis() {
265 return std::make_unique<HloCostAnalysis>(
266 client_->backend().compiler()->ShapeSizeBytesFunction());
267 }
268
269 namespace {
270
271 // Ensures that it is safe to deallocate any buffers that have been enqueued in
272 // an operation on stream. Called only in rare error cases that are triggered
273 // during enqueue. These cases generally correspond to resource exhaustion.
StallStreamOnError(LocalDeviceState * local_device,se::Stream * stream)274 void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
275 switch (local_device->allocation_model()) {
276 case LocalDeviceState::kAsynchronous:
277 // We can safely deallocate any dangling buffers immediately. NOTE: this
278 // assumes that any buffers enqueued on stream are local to stream's
279 // executor, and manual action may be needed if that condition is not met.
280 break;
281
282 case LocalDeviceState::kComputeSynchronized:
283 // This will stall computation but that's ok in this very rare error
284 // case.
285 if (stream != local_device->compute_stream()) {
286 local_device->compute_stream()->ThenWaitFor(stream);
287 }
288 break;
289
290 case LocalDeviceState::kSynchronous:
291 // This will stall the calling thread but that's ok in this very rare
292 // error case. If the stall fails just crash, since we have no other
293 // way to synchronize.
294 TF_CHECK_OK(stream->BlockHostUntilDone());
295 break;
296 }
297 }
298
299 // Does all necessary bookkeeping, after a buffer is successfully enqueued onto
300 // a stream, to ensure that the buffer will be kept alive until its use on that
301 // stream is complete.
302 //
303 // device_buffer: the buffer that was enqueued.
304 // buffer_local_device: the device the buffer was allocated on.
305 // stream_local_device: the device that manages usage_stream.
306 // event: an event that was recorded on usage_stream
307 // after the usage of device_buffer was enqueued.
308 // usage_stream: the stream the operation using device_buffer
309 // was enqueued on.
310 // prefer_to_retain_reference: relevant only for the compute synchronous
311 // allocation model. If true, retain a reference
312 // to device_buffer until after the operation
313 // completes. If false then the compute stream
314 // will have to be synchronized past event before
315 // device_buffer can be freed.
316 //
317 // prefer_to_retain_reference encodes a heuristic set by the caller for the
318 // compute synchronous model:
319 //
320 // Generally when a buffer is the destination of a copy to a device, it will
321 // subsequently be used on the device's compute stream before being freed. In
322 // that case, there is no need to retain a reference to the buffer. If the
323 // buffer is freed before being used on the compute stream, the free will be
324 // delayed until the host knows that event has completed, but this is expected
325 // to be uncommon.
326 //
327 // When a buffer is the source of a copy from a device, we need to either retain
328 // a reference to the buffer until the copy completes or serialize the compute
329 // stream behind the copy. It is often better to retain a reference since while
330 // that keeps memory alive longer, it avoids stalling the compute stream.
RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,LocalDeviceState * buffer_local_device,LocalDeviceState * stream_local_device,std::shared_ptr<BufferSequencingEvent> event,se::Stream * usage_stream,bool prefer_to_retain_reference,std::vector<std::shared_ptr<TrackedDeviceBuffer>> * buffers_to_release=nullptr)331 void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
332 LocalDeviceState* buffer_local_device,
333 LocalDeviceState* stream_local_device,
334 std::shared_ptr<BufferSequencingEvent> event,
335 se::Stream* usage_stream, bool prefer_to_retain_reference,
336 std::vector<std::shared_ptr<TrackedDeviceBuffer>>*
337 buffers_to_release = nullptr) {
338 tensorflow::profiler::TraceMe traceme("RecordUsage");
339 bool retain_buffer_until_completion =
340 // If the buffer wasn't allocated on the same device as the stream, always
341 // retain a reference.
342 (stream_local_device != buffer_local_device) ||
343 // In the synchronous allocation model, always retain a reference.
344 (stream_local_device->allocation_model() ==
345 LocalDeviceState::kSynchronous) ||
346 // In the compute synchronous model, use the caller's heuristic.
347 (stream_local_device->allocation_model() ==
348 LocalDeviceState::kComputeSynchronized &&
349 prefer_to_retain_reference);
350 if (retain_buffer_until_completion) {
351 if (buffers_to_release) {
352 buffers_to_release->push_back(device_buffer.buffer());
353 } else {
354 buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer());
355 }
356 }
357 device_buffer.ConvertUsageHold(usage_stream, event,
358 retain_buffer_until_completion);
359 }
360
361 // Allocates the device buffers for a buffer that will be used as the
362 // destination of a copy, either from the host or another device. copy_stream
363 // may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the
364 // buffer is a tuple then the tuple tables are allocated, and all necessary
365 // synchronization for them is dealt with, before the buffer is returned.
366 //
367 // It is safe to delete the returned PjRtBuffer without further
368 // synchronization if an error occurs before the buffer is used.
369 //
370 // The caller may optionally provide a definition event to be recorded in
371 // the buffer.
372 // TODO(phawkins): replace on_host_shape here with on_device_shape.
AllocateDestinationBuffer(const Shape & on_host_shape,PjRtDevice * device,LocalDeviceState * local_device,se::Stream * copy_stream,bool is_uninitialized_create,PjRtClient * client,std::shared_ptr<BufferSequencingEvent> definition_event=nullptr)373 StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
374 const Shape& on_host_shape, PjRtDevice* device,
375 LocalDeviceState* local_device, se::Stream* copy_stream,
376 bool is_uninitialized_create, PjRtClient* client,
377 std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
378 if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
379 return InvalidArgument("Can't make a buffer from an empty tuple");
380 }
381
382 auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
383 TransferManager* transfer_manager =
384 se_client->client()->backend().transfer_manager();
385 TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
386 transfer_manager->AllocateScopedShapedBuffer(
387 on_host_shape, se_client->allocator(),
388 local_device->device_ordinal()));
389 if (local_device->allocation_model() ==
390 LocalDeviceState::kComputeSynchronized) {
391 if (copy_stream == nullptr) {
392 CHECK(is_uninitialized_create);
393 } else {
394 copy_stream->ThenWaitFor(local_device->compute_stream());
395 }
396 } else {
397 DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
398 local_device->compute_stream()->parent(), dst_buffer));
399 }
400 Shape on_device_shape = dst_buffer.on_device_shape();
401
402 absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
403 definition_events;
404 if (is_uninitialized_create) {
405 // There is not going to be any copy into the buffer so in general we don't
406 // need a definition event.
407 if (local_device->allocation_model() ==
408 LocalDeviceState::kComputeSynchronized) {
409 // The allocation is not valid until the compute stream passes this point,
410 // so add a definition event in the compute stream.
411 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
412 TF_ASSIGN_OR_RETURN(EventPool::Handle event,
413 local_device->event_pool().ThenAllocateAndRecordEvent(
414 local_device->compute_stream()));
415 definition_events.back()->SetSequencingEvent(
416 std::move(event), local_device->compute_stream());
417 }
418 // if the caller provided a definition event then we record that.
419 if (definition_event) {
420 definition_events.emplace_back(definition_event);
421 }
422 } else {
423 // We have at least one definition event, for the copy completing to
424 // the device buffers.
425 if (definition_event) {
426 definition_events.emplace_back(definition_event);
427 } else {
428 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
429 }
430 }
431 se::Stream* tuple_table_stream = local_device->host_to_device_stream();
432 if (on_device_shape.IsTuple()) {
433 // We also need to copy the tuple tables, so we'll have an additional
434 // definition event for that copy to complete.
435 if (tuple_table_stream != copy_stream) {
436 if (local_device->allocation_model() ==
437 LocalDeviceState::kComputeSynchronized) {
438 tuple_table_stream->ThenWaitFor(local_device->compute_stream());
439 } else {
440 DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
441 local_device->compute_stream()->parent(), dst_buffer));
442 }
443 }
444
445 TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
446 tuple_table_stream, dst_buffer));
447 // CAUTION: From this point onwards we need to be careful about returning
448 // from error cases because we have started a transfer and must not allow
449 // dst_buffer to be freed too soon in the non-async allocation models.
450
451 definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
452 StatusOr<EventPool::Handle> event_or =
453 local_device->event_pool().ThenAllocateAndRecordEvent(
454 tuple_table_stream);
455 if (!event_or.ok()) {
456 StallStreamOnError(local_device, tuple_table_stream);
457 return event_or.status();
458 }
459 definition_events.back()->SetSequencingEvent(std::move(event_or).value(),
460 tuple_table_stream);
461 }
462 std::shared_ptr<TrackedDeviceBuffer> dst_device_buffer =
463 TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
464 definition_events);
465
466 auto py_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
467 on_device_shape, std::move(dst_device_buffer), client, device);
468
469 if (on_device_shape.IsTuple()) {
470 // Add a usage hold for the tuple table write and immediately convert it to
471 // the appropriate form of synchronization. prefer_to_retain_reference=false
472 // means don't retain a memory reference until the transfer is complete when
473 // using the ComputeSynchronized allocation model. This is a heuristic
474 // because in the common case destination buffers will be used on the
475 // compute stream and therefore don't require any synchronization before
476 // being freed. If the buffer is allocated and never used, the free will
477 // take longer and this is assumed to be ok.
478 RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
479 definition_events.back(), tuple_table_stream,
480 /*prefer_to_retain_reference=*/false);
481 }
482
483 return py_buffer;
484 }
485
486 // Adds necessary synchronization after a copy has been enqueued to a buffer.
487 // definition_event was added when the buffer was allocated, but has not yet
488 // had an event recorded.
AddDestinationBufferSynchronization(LocalDeviceState * local_device,PjRtStreamExecutorBuffer::ScopedHold device_buffer,std::shared_ptr<BufferSequencingEvent> definition_event,se::Stream * copy_stream)489 Status AddDestinationBufferSynchronization(
490 LocalDeviceState* local_device,
491 PjRtStreamExecutorBuffer::ScopedHold device_buffer,
492 std::shared_ptr<BufferSequencingEvent> definition_event,
493 se::Stream* copy_stream) {
494 StatusOr<EventPool::Handle> event_or =
495 local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream);
496 if (!event_or.ok()) {
497 StallStreamOnError(local_device, copy_stream);
498 return event_or.status();
499 }
500 definition_event->SetSequencingEvent(std::move(event_or).value(),
501 copy_stream);
502 // prefer_to_retain_reference=false means don't retain a memory reference
503 // until the transfer is complete when using the ComputeSynchronized
504 // allocation model. This is a heuristic because in the common case
505 // destination buffers will be used on the compute stream and therefore don't
506 // require any synchronization before being freed. If the buffer is allocated
507 // and never used, the free will take longer and this is assumed to be ok.
508 RecordUsage(std::move(device_buffer), local_device, local_device,
509 definition_event, copy_stream,
510 /*prefer_to_retain_reference=*/false);
511 return OkStatus();
512 }
513
514 } // namespace
515
~ScopedHold()516 PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
517 if (ok()) {
518 parent_->DropHold(type_, buffer().get());
519 }
520 }
521
ScopedHold(ScopedHold && other)522 PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
523 : parent_(other.parent_),
524 type_(other.type_),
525 state_(other.state_),
526 status_(std::move(other.status_)),
527 buffer_(std::move(other.buffer_)) {
528 // Preserve the invariant that status is invalid if buffer == nullptr.
529 other.SetState(kMoved);
530 }
531
Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>> && buffer_or)532 void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
533 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
534 CHECK(!ok());
535 if (buffer_or.ok()) {
536 buffer_ = buffer_or.ValueOrDie();
537 SetState(kValid);
538 } else {
539 status_ = buffer_or.status();
540 buffer_ = nullptr;
541 SetState(kError);
542 }
543 // Check the invariant holds.
544 CHECK(!ok() || buffer_ != nullptr);
545 }
546
547 PjRtStreamExecutorBuffer::ScopedHold::ForClosure
ToClosure()548 PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
549 CHECK(ok());
550 ForClosure for_closure(parent_, type_, state_, std::move(status_),
551 std::move(buffer_));
552 SetState(kReleased);
553 return for_closure;
554 }
555
ConvertUsageHold(se::Stream * usage_stream,std::shared_ptr<BufferSequencingEvent> event,bool reference_held)556 void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
557 se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
558 bool reference_held) {
559 CHECK(ok());
560 CHECK_EQ(type_, kUsage);
561 parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
562 reference_held);
563 SetState(kConverted);
564 }
565
ConfirmDonation()566 void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
567 CHECK(ok());
568 CHECK_EQ(type_, kDonation);
569 parent_->ConfirmDonation(buffer().get());
570 SetState(kDonated);
571 }
572
AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator * iterator,const ShapeTree<MaybeOwningDeviceMemory>::iterator & end,ExecutionInput * execution_input,se::DeviceMemoryAllocator * allocator) const573 void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
574 ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
575 const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
576 ExecutionInput* execution_input,
577 se::DeviceMemoryAllocator* allocator) const {
578 CHECK(ok());
579 if (type_ == kDonation) {
580 buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
581 } else {
582 CHECK_EQ(type_, kUsage);
583 buffer()->AddToInputAsImmutable(iterator, end);
584 }
585 }
586
IsOnCpu() const587 bool PjRtStreamExecutorBuffer::IsOnCpu() const {
588 return client()->platform_id() == CpuId();
589 }
590
logical_on_device_shape()591 StatusOr<Shape> PjRtStreamExecutorBuffer::logical_on_device_shape() {
592 if (on_device_shape_.is_static()) {
593 return on_device_shape_;
594 }
595 auto* local_device = device_->local_device_state();
596 auto* stream = local_device->GetDeviceToHostStream();
597 ScopedHold device_buffer(this, ScopedHold::kUsage);
598 {
599 absl::MutexLock lock(&mu_);
600 // We can't perform any other action while a donation hold is in progress.
601 WaitForOutstandingDonationHold();
602 if (device_buffer_ == nullptr) {
603 return InvalidArgument(
604 "logical_on_device_shape() called on deleted or donated buffer");
605 }
606 AcquireHoldLocked(&device_buffer);
607 }
608
609 WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
610 ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
611 StatusOr<EventPool::Handle> event_or =
612 local_device->event_pool().AllocateEvent(stream->parent());
613 if (!event_or.ok()) {
614 return event_or.status();
615 }
616 Shape ret_shape = on_device_shape_;
617 TransferManager* transfer_manager =
618 client_->client()->backend().transfer_manager();
619 TF_RETURN_IF_ERROR(
620 transfer_manager->ReadDynamicShapes(stream, &shaped_buffer, &ret_shape));
621 return ret_shape;
622 }
623
624 namespace {
625
626 // Implements PjRtBuffer::ExternalReference as a wrapped
627 // ScopedHold::kExternalReference.
628 class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
629 public:
ScopedHoldAsExternalReference(PjRtStreamExecutorBuffer::ScopedHold hold)630 explicit ScopedHoldAsExternalReference(
631 PjRtStreamExecutorBuffer::ScopedHold hold)
632 : external_reference_(std::move(hold)) {
633 CHECK(external_reference_.type() ==
634 PjRtStreamExecutorBuffer::ScopedHold::kExternalReference);
635 data_ptr_ = external_reference_->device_memory().front().opaque();
636 }
637
638 ~ScopedHoldAsExternalReference() override = default;
639
640 private:
641 PjRtStreamExecutorBuffer::ScopedHold external_reference_;
642 };
643
644 } // namespace
645
646 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
AcquireExternalReference()647 PjRtStreamExecutorBuffer::AcquireExternalReference() {
648 ScopedHold hold = GetBufferWithExternalReference();
649 Status hold_status = hold.status();
650 if (!hold_status.ok()) return hold_status;
651 return std::unique_ptr<ExternalReference>(
652 std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
653 }
654
655 class TrackedDeviceBufferExternalReference
656 : public PjRtBuffer::ExternalReference {
657 public:
TrackedDeviceBufferExternalReference(std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)658 explicit TrackedDeviceBufferExternalReference(
659 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)
660 : tracked_device_buffer_(std::move(tracked_device_buffer)) {
661 data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque();
662 }
663
664 ~TrackedDeviceBufferExternalReference() override = default;
665
666 private:
667 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer_;
668 };
669
670 StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete)671 PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
672 bool wait_for_operations_to_complete) {
673 if (on_device_shape_.IsTuple()) {
674 return InvalidArgument(
675 "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
676 }
677 TF_ASSIGN_OR_RETURN(
678 std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
679 Release(wait_for_operations_to_complete));
680
681 std::unique_ptr<PjRtBuffer::ExternalReference> ref;
682 if (tracked_device_buffer) {
683 ref = std::make_unique<TrackedDeviceBufferExternalReference>(
684 std::move(tracked_device_buffer));
685 }
686 return ref;
687 }
688
689 StatusOr<std::unique_ptr<PjRtBuffer>>
BufferFromHostBuffer(const void * data,PrimitiveType type,absl::Span<int64_t const> dims,std::optional<absl::Span<int64_t const>> byte_strides,HostBufferSemantics host_buffer_semantics,std::function<void ()> on_done_with_host_buffer,PjRtDevice * device)690 PjRtStreamExecutorClient::BufferFromHostBuffer(
691 const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
692 std::optional<absl::Span<int64_t const>> byte_strides,
693 HostBufferSemantics host_buffer_semantics,
694 std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
695 tensorflow::profiler::TraceMe traceme(
696 "PjRtStreamExecutorClient::BufferFromHostBuffer");
697 Shape shape = ShapeUtil::MakeShape(type, dims);
698 VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
699 << shape.ToString() << " device: " << device->DebugString();
700 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
701 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
702 ->GetLocalDeviceState());
703
704 absl::InlinedVector<int64_t, 4> tmp_strides;
705 if (!byte_strides) {
706 tmp_strides.resize(dims.size());
707 TF_RETURN_IF_ERROR(
708 ShapeUtil::ByteStrides(shape, absl::MakeSpan(tmp_strides)));
709 byte_strides = tmp_strides;
710 }
711 int64_t size = ShapeUtil::ByteSizeOf(shape);
712
713 TransferManager* transfer_manager = client()->backend().transfer_manager();
714 TF_ASSIGN_OR_RETURN(Shape compact_shape,
715 transfer_manager->ChooseCompactLayoutForShape(shape));
716 absl::InlinedVector<int64_t, 4> compact_shape_strides(
717 compact_shape.dimensions_size());
718 TF_RETURN_IF_ERROR(ShapeUtil::ByteStrides(
719 compact_shape, absl::MakeSpan(compact_shape_strides)));
720 bool host_and_device_strides_equal =
721 (size == 0 || *byte_strides == compact_shape_strides);
722 // The CPU platform is special because the "host" and the "device" are in the
723 // same memory space. If the input shape is in the correct layout and we don't
724 // want to defer the copy onto a thread, we can use the following fast
725 // path.
726 bool is_cpu_platform =
727 local_device->executor()->platform()->id() == se::host::kHostPlatformId;
728 if (is_cpu_platform) {
729 // If we are on the host platform and the input buffer is sufficiently
730 // aligned, we can simply point to the input array's data without any
731 // further copies. At the time of writing we require a 16-byte alignment
732 // because XLA may generate code which requires it.
733 bool can_use_zero_copy =
734 host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
735 ((absl::bit_cast<std::uintptr_t>(data) &
736 (cpu_function_runtime::MinAlign() - 1)) == 0);
737 if (host_and_device_strides_equal &&
738 (host_buffer_semantics ==
739 HostBufferSemantics::kImmutableOnlyDuringCall ||
740 can_use_zero_copy)) {
741 std::function<void()> on_delete_callback;
742 se::DeviceMemoryBase buffer;
743 // If we are on the host platform and the input buffer is sufficiently
744 // aligned, we can simply point to the input array's data without any
745 // further copies. At the time of writing we require a 16-byte alignment
746 // because XLA may generate code which requires it.
747 if (can_use_zero_copy) {
748 on_delete_callback = std::move(on_done_with_host_buffer);
749 buffer = se::DeviceMemoryBase(
750 const_cast<void*>(static_cast<const void*>(data)), size);
751 } else {
752 void* staging_buffer = host_memory_allocator()->AllocateRaw(
753 cpu_function_runtime::MinAlign(), size);
754 buffer = se::DeviceMemoryBase(staging_buffer, size);
755 std::memcpy(staging_buffer, data, size);
756 if (on_done_with_host_buffer) {
757 on_done_with_host_buffer();
758 }
759 on_delete_callback = [staging_buffer, host_memory_allocator =
760 host_memory_allocator()]() {
761 host_memory_allocator->DeallocateRaw(staging_buffer);
762 };
763 }
764 absl::Span<const std::shared_ptr<BufferSequencingEvent>>
765 definition_events;
766 auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
767 /*allocator=*/nullptr, local_device->device_ordinal(),
768 std::initializer_list<se::DeviceMemoryBase>{buffer},
769 definition_events, std::move(on_delete_callback));
770 return std::unique_ptr<PjRtBuffer>(
771 std::make_unique<PjRtStreamExecutorBuffer>(
772 shape, std::move(device_buffer), this, device));
773 }
774 }
775
776 TF_ASSIGN_OR_RETURN(
777 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
778 AllocateDestinationBuffer(compact_shape, device, local_device,
779 local_device->host_to_device_stream(),
780 /*is_uninitialized_create=*/false, this));
781
782 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
783 py_buffer->GetBufferWithUsageHold());
784 CHECK(device_buffer.ok());
785
786 // If necessary, allocate a host-side buffer for staging host-to-device
787 // transfers. On GPU this is a buffer in pinned memory.
788 std::shared_ptr<void> staging_buffer;
789 if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
790 should_stage_host_to_device_transfers() ||
791 !host_and_device_strides_equal) {
792 void* ptr = host_memory_allocator()->AllocateRaw(
793 tensorflow::Allocator::kAllocatorAlignment, size);
794 staging_buffer = std::shared_ptr<void>(
795 ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
796 host_memory_allocator->DeallocateRaw(ptr);
797 });
798 }
799
800 std::shared_ptr<TransposePlan> transpose;
801 if (!host_and_device_strides_equal) {
802 absl::InlinedVector<int64_t, 4> permutation(dims.size());
803 absl::c_reverse_copy(compact_shape.layout().minor_to_major(),
804 permutation.begin());
805 absl::MutexLock lock(&transpose_mu_);
806 TF_ASSIGN_OR_RETURN(transpose,
807 transpose_cache_.GetOrCreate(
808 primitive_util::ByteWidth(type), dims, permutation,
809 TransposePlan::Striding{*byte_strides}));
810 }
811
812 // Copy the buffer into a staging buffer before returning control to the
813 // caller if the caller only guaranteed that the buffer is valid for the
814 // duration of the call. Otherwise, we stage (if necessary) on a separate
815 // thread.
816 if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
817 if (transpose) {
818 transpose->Execute(data, staging_buffer.get());
819 } else {
820 std::memcpy(staging_buffer.get(), data, size);
821 }
822 if (on_done_with_host_buffer) {
823 on_done_with_host_buffer();
824 on_done_with_host_buffer = nullptr;
825 }
826 }
827
828 // The host to device transfer is performed on a thread pool, mostly because
829 // it includes linearization that may be slow. It is OK to capture the
830 // py_buffer pointer because the py_buffer can't be deleted until all the
831 // usage holds have gone away.
832 // TODO(misard) assess if it would be preferable to introduce a heuristic to
833 // put the transfer into the calling thread for small literals.
834 auto transfer_h2d =
835 [local_client = client(), transfer_manager, local_device, data, size,
836 movable_device_buffer{device_buffer.ToClosure()}, shape,
837 py_buffer{py_buffer.get()},
838 on_device_shape{py_buffer->on_device_shape()},
839 staging_buffer{std::move(staging_buffer)},
840 on_done_with_host_buffer{std::move(on_done_with_host_buffer)},
841 host_buffer_semantics, transpose{std::move(transpose)}]() {
842 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
843 movable_device_buffer);
844 // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
845 // to report failures from a callback. However, the operations here are
846 // unlikely to fail and not recoverable even if we were to fail: DMAs to
847 // memory that has already been allocated, and a possible Event
848 // allocation.
849
850 ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
851 // If applicable on the backend, stage the transfer via host memory
852 // allocated via the host_memory_allocator. On GPU, this is pinned
853 // memory.
854 if (staging_buffer) {
855 // If we didn't already copy the input buffer into the staging buffer,
856 // do so now.
857 if (host_buffer_semantics !=
858 HostBufferSemantics::kImmutableOnlyDuringCall) {
859 if (transpose) {
860 transpose->Execute(data, staging_buffer.get());
861 } else {
862 std::memcpy(staging_buffer.get(), data, size);
863 }
864 }
865 // The buffer has the same dimension order as the on-device shape, but
866 // is not tiled, etc.
867 BorrowingLiteral literal(
868 static_cast<const char*>(staging_buffer.get()),
869 ShapeUtil::DeviceShapeToHostShape(on_device_shape));
870 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
871 local_device->host_to_device_stream(), literal, buffer));
872 } else {
873 BorrowingLiteral literal(
874 reinterpret_cast<const char*>(data),
875 ShapeUtil::DeviceShapeToHostShape(on_device_shape));
876 // Otherwise, just transfer the literal.
877 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
878 local_device->host_to_device_stream(), literal, buffer));
879 }
880
881 std::shared_ptr<BufferSequencingEvent> event =
882 device_buffer->definition_events()[0];
883 TF_CHECK_OK(AddDestinationBufferSynchronization(
884 local_device, std::move(device_buffer), event,
885 local_device->host_to_device_stream()));
886
887 local_device->ThenExecuteCallback(
888 local_device->host_to_device_stream(),
889 [staging_buffer{std::move(staging_buffer)},
890 on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
891 if (on_done_with_host_buffer) {
892 on_done_with_host_buffer();
893 }
894 });
895 };
896 if (is_cpu_platform) {
897 // Using the thread_pool would be a double thread hop; the code
898 // already defers its work onto a stream (= thread on CPU).
899 transfer_h2d();
900 } else {
901 thread_pool()->Schedule(transfer_h2d);
902 }
903 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
904 }
905
906 StatusOr<std::unique_ptr<PjRtBuffer>>
CreateUninitializedBuffer(const Shape & shape,PjRtDevice * device)907 PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
908 PjRtDevice* device) {
909 return CreateUninitializedBuffer(shape, device, nullptr);
910 }
911
912 StatusOr<std::unique_ptr<PjRtBuffer>>
CreateUninitializedBuffer(const Shape & shape,PjRtDevice * device,std::shared_ptr<BufferSequencingEvent> definition_event)913 PjRtStreamExecutorClient::CreateUninitializedBuffer(
914 const Shape& shape, PjRtDevice* device,
915 std::shared_ptr<BufferSequencingEvent> definition_event) {
916 tensorflow::profiler::TraceMe traceme(
917 "PjRtStreamExecutorClient::CreateUninitializedBuffer");
918 VLOG(1) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
919 << shape.ToString() << " device: " << device->DebugString();
920 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
921 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
922 ->GetLocalDeviceState());
923
924 TransferManager* transfer_manager = client()->backend().transfer_manager();
925 TF_ASSIGN_OR_RETURN(Shape compact_shape,
926 transfer_manager->ChooseCompactLayoutForShape(shape));
927
928 TF_ASSIGN_OR_RETURN(
929 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
930 AllocateDestinationBuffer(compact_shape, device, local_device,
931 /*copy_stream=*/nullptr,
932 /*is_uninitialized_create=*/true, this,
933 definition_event));
934 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
935 }
936
937 StatusOr<std::unique_ptr<PjRtBuffer>>
BufferFromHostLiteral(const LiteralSlice & literal,PjRtDevice * device)938 PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
939 PjRtDevice* device) {
940 tensorflow::profiler::TraceMe traceme(
941 "PjRtStreamExecutorClient::BufferFromHostLiteral");
942 VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
943 << literal.shape().ToString() << " device: " << device->DebugString();
944 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
945 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
946 ->GetLocalDeviceState());
947
948 TransferManager* transfer_manager = client()->backend().transfer_manager();
949 TF_ASSIGN_OR_RETURN(
950 Shape compact_shape,
951 transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
952 TF_ASSIGN_OR_RETURN(
953 std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
954 AllocateDestinationBuffer(compact_shape, device, local_device,
955 local_device->host_to_device_stream(),
956 /*is_uninitialized_create=*/false, this));
957
958 PjRtStreamExecutorBuffer::ScopedHold device_buffer(
959 py_buffer->GetBufferWithUsageHold());
960 CHECK(device_buffer.ok());
961
962 // The host to device transfer is performed on a thread pool, mostly because
963 // it includes linearization that may be slow. It is OK to capture the
964 // py_buffer pointer because the py_buffer can't be deleted until all the
965 // usage holds have gone away.
966 // TODO(misard) assess if it would be preferable to introduce a heuristic to
967 // put the transfer into the calling thread for small literals.
968 auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
969 movable_device_buffer{device_buffer.ToClosure()},
970 literal, py_buffer{py_buffer.get()},
971 on_device_shape{py_buffer->on_device_shape()}]() {
972 PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
973 // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
974 // to report failures from a callback. However, the operations here are
975 // unlikely to fail and not recoverable even if we were to fail: DMAs to
976 // memory that has already been allocated, and a possible Event
977 // allocation.
978
979 se::Stream* h2d_stream = local_device->host_to_device_stream();
980 ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
981 TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
982 h2d_stream, literal, buffer));
983
984 std::shared_ptr<BufferSequencingEvent> event =
985 device_buffer->definition_events()[0];
986 TF_CHECK_OK(AddDestinationBufferSynchronization(
987 local_device, std::move(device_buffer), event, h2d_stream));
988
989 // This can sometimes catch the case where the literal memory has been
990 // freed before the H2D transfer was issued.
991 h2d_stream->RefreshStatus()
992 .IgnoreError(); // Can return error::Unimplemented
993 QCHECK(h2d_stream->ok());
994 };
995 thread_pool()->Schedule(transfer_h2d);
996 return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
997 }
998
999 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)1000 PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
1001 absl::Span<const Shape> shapes, PjRtDevice* device,
1002 PjRtCrossHostRecvNotifier notifier) {
1003 if (shapes.empty()) {
1004 return InvalidArgument(
1005 "shapes parameter empty in MakeCrossHostReceiveBuffers");
1006 }
1007
1008 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
1009 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1010 ->GetLocalDeviceState());
1011 std::shared_ptr<BufferSequencingEvent> definition_event =
1012 std::make_shared<BufferSequencingEvent>();
1013 std::vector<std::unique_ptr<PjRtBuffer>> buffers;
1014 buffers.reserve(shapes.size());
1015 for (const auto& shape : shapes) {
1016 TF_ASSIGN_OR_RETURN(
1017 std::unique_ptr<PjRtBuffer> buffer,
1018 AllocateDestinationBuffer(shape, device, local_device,
1019 /*copy_stream=*/nullptr,
1020 /*is_uninitialized_create=*/false, this,
1021 definition_event));
1022 buffers.push_back(std::move(buffer));
1023 }
1024
1025 TF_RETURN_IF_ERROR(EnqueueCrossHostReceive(
1026 buffers, std::move(definition_event), std::move(notifier), std::nullopt));
1027 return buffers;
1028 }
1029
1030 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffersForGather(absl::Span<const Shape> shapes,std::vector<GatherDetails> gather_details,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)1031 PjRtStreamExecutorClient::MakeCrossHostReceiveBuffersForGather(
1032 absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
1033 PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) {
1034 VLOG(2) << "Making " << gather_details.size()
1035 << " cross host receive buffers for gather";
1036 if (gather_details.empty()) {
1037 return InvalidArgument(
1038 "gather_details parameter empty in "
1039 "MakeCrossHostReceiveBuffersForGather");
1040 }
1041
1042 if (shapes.size() != gather_details.size()) {
1043 return InvalidArgument(
1044 "gather_details parameter has length %lld but shapes "
1045 "parameter has length %lld in "
1046 "MakeCrossHostReceiveBuffersForGather",
1047 gather_details.size(), shapes.size());
1048 }
1049
1050 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
1051 tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1052 ->GetLocalDeviceState());
1053 std::shared_ptr<BufferSequencingEvent> definition_event =
1054 std::make_shared<BufferSequencingEvent>();
1055 std::vector<std::unique_ptr<PjRtBuffer>> buffers;
1056 buffers.reserve(shapes.size());
1057 for (int i = 0; i < shapes.size(); ++i) {
1058 TF_ASSIGN_OR_RETURN(
1059 std::unique_ptr<PjRtBuffer> buffer,
1060 AllocateDestinationBuffer(shapes[i], device, local_device,
1061 /*copy_stream=*/nullptr,
1062 /*is_uninitialized_create=*/false, this,
1063 definition_event));
1064 buffers.push_back(std::move(buffer));
1065 }
1066
1067 TF_RETURN_IF_ERROR(
1068 EnqueueCrossHostReceive(buffers, std::move(definition_event),
1069 std::move(notifier), gather_details));
1070 return buffers;
1071 }
1072
1073 StatusOr<std::unique_ptr<PjRtBuffer>>
CreateViewOfDeviceBuffer(void * device_ptr,const Shape & shape,PjRtDevice * device,std::function<void ()> on_delete_callback)1074 PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
1075 void* device_ptr, const Shape& shape, PjRtDevice* device,
1076 std::function<void()> on_delete_callback) {
1077 se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape));
1078 absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
1079 auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
1080 /*allocator=*/nullptr, device->local_hardware_id(),
1081 std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
1082 std::move(on_delete_callback));
1083 return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
1084 shape, std::move(device_buffer), this, device));
1085 }
1086
1087 // Transfer the given literal to the infeed queue of the given local device.
TransferToInfeed(const LiteralSlice & literal)1088 Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
1089 // Only support infeed to local device.
1090 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1091 return local_device->client()->TransferToInfeedLocal(
1092 literal, local_device->device_ordinal());
1093 }
1094
TransferFromOutfeed(MutableBorrowingLiteral literal)1095 Status PjRtStreamExecutorDevice::TransferFromOutfeed(
1096 MutableBorrowingLiteral literal) {
1097 VLOG(1) << "PjRtStreamExecutorDevice::TransferFromOutfeed";
1098 TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1099 return local_device->client()->TransferFromOutfeedLocal(
1100 local_device->device_ordinal(), literal);
1101 }
1102
LookupAddressableDevice(int local_hardware_id) const1103 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
1104 int local_hardware_id) const {
1105 for (auto* device : addressable_devices_) {
1106 if (local_hardware_id == device->local_hardware_id()) {
1107 return device;
1108 }
1109 }
1110 return InvalidArgument("No matching device found for local_hardware_id %d",
1111 local_hardware_id);
1112 }
1113
PjRtStreamExecutorBuffer(Shape on_device_shape,std::shared_ptr<TrackedDeviceBuffer> device_buffer,PjRtClient * client,PjRtDevice * device)1114 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
1115 Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
1116 PjRtClient* client, PjRtDevice* device)
1117 : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
1118 on_device_shape_(std::move(on_device_shape)),
1119 device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
1120 device_buffer_(std::move(device_buffer)) {
1121 for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1122 holds_[i] = 0;
1123 }
1124 }
1125
~PjRtStreamExecutorBuffer()1126 PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
1127 Delete();
1128 for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1129 CHECK_EQ(holds_[i], 0);
1130 }
1131 }
1132
WaitForOutstandingUsageHolds()1133 void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
1134 auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1135 return holds_[ScopedHold::kUsage] == 0;
1136 };
1137 mu_.Await(absl::Condition(¬_in_usage_hold));
1138 }
1139
WaitForOutstandingDonationHold()1140 void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
1141 auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1142 return holds_[ScopedHold::kDonation] == 0;
1143 };
1144 mu_.Await(absl::Condition(¬_in_donation_hold));
1145 }
1146
1147 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
Release(bool wait_for_operations_to_complete)1148 PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
1149 tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
1150 std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1151 TrackedDeviceBuffer::StreamAndEventContainer events;
1152 {
1153 absl::MutexLock lock(&mu_);
1154 // We first wait for a donation hold to complete if there is one in
1155 // progress. If the donation succeeds via ConfirmDonation() then it will
1156 // set device_buffer_ to nullptr before returning to this thread.
1157 WaitForOutstandingDonationHold();
1158 if (device_buffer_ == nullptr) {
1159 return std::shared_ptr<TrackedDeviceBuffer>();
1160 }
1161 // Set device_buffer_ to null now so that no other
1162 // thread can add a hold while we are in WaitForOutstandingUsageHolds()
1163 // below.
1164 std::swap(device_buffer_, device_buffer);
1165 WaitForOutstandingUsageHolds();
1166 // Now that all holds have completed and no more can be added, we can get
1167 // the final set of usage events.
1168 events = device_buffer->LockUseAndTransferUsageEvents();
1169 }
1170 LocalDeviceState* local_device_state = device_->local_device_state();
1171 if (wait_for_operations_to_complete) {
1172 // Block the host until all usage events have completed. Usage events
1173 // dominate definition events, so this also waits for the buffer to be
1174 // defined.
1175 std::unique_ptr<se::Stream> stream;
1176 for (const auto& stream_and_event : events) {
1177 if (!stream_and_event.event->IsComplete()) {
1178 if (stream == nullptr) {
1179 stream = local_device_state->BorrowStreamFromPool();
1180 }
1181 stream_and_event.event->WaitForEventOnStream(stream.get());
1182 }
1183 }
1184 if (stream != nullptr) {
1185 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1186 local_device_state->ReturnStreamToPool(std::move(stream));
1187 }
1188 } else {
1189 if (local_device_state->allocation_model() ==
1190 LocalDeviceState::kComputeSynchronized) {
1191 std::unique_ptr<se::Stream> block_stream;
1192 for (const auto& stream_and_event : events) {
1193 // We only need to do something for events that didn't already acquire a
1194 // reference to the buffer, and also which the compute stream didn't
1195 // already wait for. Based on our heuristics this rare case should only
1196 // occur when a buffer was copied to a device and then never used there.
1197 // In that case we get a new stream and use it to hold onto a reference
1198 // to the buffer until the events are complete.
1199 if (!stream_and_event.reference_held &&
1200 !stream_and_event.event->DefinedOn(
1201 local_device_state->compute_stream()) &&
1202 !stream_and_event.event->IsComplete()) {
1203 if (block_stream == nullptr) {
1204 block_stream = local_device_state->BorrowStreamFromPool();
1205 }
1206 stream_and_event.event->WaitForEventOnStream(block_stream.get());
1207 }
1208 }
1209 if (block_stream != nullptr) {
1210 se::Stream* block_stream_ptr = block_stream.release();
1211 local_device_state->ThenExecuteCallback(
1212 block_stream_ptr,
1213 [device_buffer, block_stream_ptr, local_device_state]() {
1214 local_device_state->ReturnStreamToPool(
1215 std::unique_ptr<se::Stream>(block_stream_ptr));
1216 });
1217 }
1218 }
1219 }
1220 return device_buffer;
1221 }
1222
Delete()1223 void PjRtStreamExecutorBuffer::Delete() {
1224 VLOG(1) << "PjRtStreamExecutorBuffer::Delete";
1225 // When wait_for_reads_to_complete is false, Release should never fail.
1226 TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
1227 }
1228
IsDeleted()1229 bool PjRtStreamExecutorBuffer::IsDeleted() {
1230 absl::MutexLock lock(&mu_);
1231 return device_buffer_ == nullptr;
1232 }
1233
1234 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
GetBufferForHoldLocked(ScopedHold::Type type)1235 PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
1236 // All callers should have called WaitForOutstandingDonationHold().
1237 CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1238 if (type == ScopedHold::kDonation) {
1239 if (device_buffer_ == nullptr) {
1240 return InvalidArgument("Donation requested for invalid buffer");
1241 }
1242 if (holds_[ScopedHold::kExternalReference] > 0) {
1243 return InvalidArgument(
1244 "Donation requested for buffer with external reference");
1245 }
1246 // First add the donation hold.
1247 ++holds_[type];
1248 // Then wait for any usage holds to be dropped or converted. No new usage
1249 // holds can be added until we drop the donation hold so this wait will
1250 // complete eventually.
1251 WaitForOutstandingUsageHolds();
1252 // Because we added a donation hold, nobody could release the buffer while
1253 // we were waiting.
1254 CHECK(device_buffer_ != nullptr);
1255 } else {
1256 if (device_buffer_ == nullptr) {
1257 return InvalidArgument("Buffer has been deleted or donated.");
1258 } else {
1259 ++holds_[type];
1260 }
1261 }
1262 return device_buffer_;
1263 }
1264
1265 void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
1266 hold->Acquire(GetBufferForHoldLocked(hold->type()));
1267 }
1268
1269 void PjRtStreamExecutorBuffer::ConvertUsageHold(
1270 TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
1271 std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
1272 absl::MutexLock lock(&mu_);
1273 CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1274 buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
1275 CHECK_GT(holds_[ScopedHold::kUsage], 0);
1276 --holds_[ScopedHold::kUsage];
1277 }
1278
1279 void PjRtStreamExecutorBuffer::ConfirmDonation(
1280 TrackedDeviceBuffer* device_buffer) {
1281 {
1282 absl::MutexLock lock(&mu_);
1283 CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1284 CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1285 CHECK_EQ(holds_[ScopedHold::kDonation], 1);
1286 holds_[ScopedHold::kDonation] = 0;
1287 CHECK(device_buffer_.get() == device_buffer);
1288 // As a sanity check ensure no more usage events can be added to the buffer.
1289 device_buffer->LockUseAndTransferUsageEvents();
1290 // Give up ownership of the device memory so we don't free it when the last
1291 // reference to device_buffer_ goes away.
1292 device_buffer->ReleaseDeviceMemory();
1293 // Make *this invalid so it can't be used again. Any threads blocking in
1294 // Release or GetBufferWithHold will see an invalid buffer and return.
1295 device_buffer_.reset();
1296 }
1297 }
1298
1299 void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
1300 TrackedDeviceBuffer* buffer) {
1301 absl::MutexLock lock(&mu_);
1302 CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1303 CHECK_GT(holds_[type], 0);
1304 --holds_[type];
1305 if (type == ScopedHold::kDonation) {
1306 CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1307 CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1308 CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1309 }
1310 }
1311
1312 PjRtFuture<Status> PjRtStreamExecutorBuffer::ToLiteral(
1313 MutableLiteralBase* literal) {
1314 VLOG(1) << "PjRtStreamExecutorBuffer::ToLiteral";
1315 if (IsEmptyTuple()) {
1316 return PjRtFuture<Status>(
1317 InvalidArgument("ToLiteral called on empty tuple"));
1318 }
1319 LocalDeviceState* local_device = device_->local_device_state();
1320 se::Stream* stream = local_device->GetDeviceToHostStream();
1321 ScopedHold device_buffer(this, ScopedHold::kUsage);
1322 {
1323 absl::MutexLock lock(&mu_);
1324 // We can't perform any other action while a donation hold is in progress.
1325 WaitForOutstandingDonationHold();
1326 if (device_buffer_ == nullptr) {
1327 return PjRtFuture<Status>(InvalidArgument(
1328           "ToLiteral() called on deleted or donated buffer"));
1329 }
1330 AcquireHoldLocked(&device_buffer);
1331 }
1332
1333 WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
1334 ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
1335 StatusOr<EventPool::Handle> event_or =
1336 local_device->event_pool().AllocateEvent(stream->parent());
1337 if (!event_or.ok()) {
1338 return PjRtFuture<Status>(event_or.status());
1339 }
1340 auto promise = PjRtFuture<Status>::CreatePromise();
1341 client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
1342 stream, shaped_buffer, literal,
1343 [promise](Status status) mutable { promise.Set(status); });
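  // TransferLiteralFromDevice enqueues the device-to-host copy on `stream` and
  // invokes the callback above with the transfer's status once it completes,
  // fulfilling the promise behind the future returned at the end of this
  // function.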
1344
1345 auto usage_event = std::make_shared<BufferSequencingEvent>();
1346 local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
1347 usage_event->SetSequencingEvent(std::move(event_or).value(), stream);
1348 // When using the ComputeSynchronized allocation model, retain a reference to
1349 // the device_buffer until the copy completes, to ensure that the buffer isn't
1350 // deleted or donated while it is still in use. The choice of retaining a
1351 // reference at the host is a heuristic; the alternative is to ensure, before
1352 // freeing the buffer, that the compute stream is synchronized past the
1353 // transfer, but it seems better to hold onto the buffer too long than to
1354 // stall the compute stream, particularly since the overwhelmingly common
1355 // use case of CopyToHostAsync will hold onto the reference long enough to
1356 // read the buffer in a subsequent call to ToLiteral.
1357 RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
1358 stream,
1359 /*prefer_to_retain_reference=*/true);
1360
1361 return PjRtFuture<Status>(
1362 std::move(promise),
1363 /*on_block_start=*/
1364 []() {
1365 tensorflow::profiler::TraceMeProducer traceme(
1366 "PjRtStreamExecutorBuffer::ToLiteral");
1367 VLOG(1) << "PjRtStreamExecutorBuffer::ToLiteral";
1368 return PjRtFutureHelpers::ProfilingKeys(
1369 {/*traceme_context_id =*/traceme.GetContextId()});
1370 },
1371 /*on_block_end=*/
1372 [](PjRtFutureHelpers::ProfilingKeys keys) {
1373 tensorflow::profiler::TraceMeConsumer traceme(
1374 "PjRtStreamExecutorBuffer::ToLiteral", keys.traceme_context_id);
1375 });
1376 }
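// Illustrative sketch of using the ToLiteral above (the `buffer` and `literal`
// names are hypothetical caller-side variables): a caller that wants a
// blocking read can allocate a host literal of the matching host shape and
// wait on the returned future, e.g.
//   Literal literal(
//       ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape()));
//   Status s = buffer->ToLiteral(&literal).Await();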
1377
1378 StatusOr<size_t> PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() const {
1379 absl::MutexLock lock(&mu_);
1380 if (device_buffer_ == nullptr) {
1381 return InvalidArgument(
1382 "GetOnDeviceSizeInBytes called on deleted or donated buffer");
1383 }
1384 if (device_buffer_->device_memory().size() != 1) {
1385 return InvalidArgument(
1386 "GetOnDeviceSizeInBytes called on tuple-shaped buffer");
1387 }
1388 return device_buffer_->device_memory()[0].size();
1389 }
1390
1391 PjRtFuture<Status> PjRtStreamExecutorBuffer::CopyRawToHost(
1392 void* dst, int64_t offset, int64_t transfer_size) {
1393 return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size);
1394 }
1395
1396 StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
1397 absl::MutexLock lock(&mu_);
1398 if (device_buffer_ == nullptr) {
1399 return InvalidArgument(
1400 "Attempted to fetch value of invalid/deleted buffer.");
1401 }
1402 return device_buffer_->AsShapedBuffer(on_device_shape_);
1403 }
1404
1405 PjRtStreamExecutorBuffer::ScopedHold
1406 PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
1407 absl::MutexLock lock(&mu_);
1408 // Ensure that at most one donation hold can be in progress at a time.
1409 WaitForOutstandingDonationHold();
1410 ScopedHold hold(this, type);
1411 AcquireHoldLocked(&hold);
1412 return hold;
1413 }
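// Illustrative sketch of the hold lifecycle around GetBufferWithHold above
// (hypothetical caller code): a usage hold keeps the TrackedDeviceBuffer alive
// while reads are enqueued, and is then converted into a recorded usage event:
//   PjRtStreamExecutorBuffer::ScopedHold hold =
//       buffer->GetBufferWithUsageHold();
//   if (hold.ok()) {
//     // ... enqueue reads of hold->device_memory() on `stream` ...
//     hold.ConvertUsageHold(stream, usage_event, /*reference_held=*/false);
//   }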
1414
1415 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1416 std::shared_ptr<BufferSequencingEvent>>>
1417 PjRtStreamExecutorBuffer::CopyToDeviceHelper(
1418 PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
1419 LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
1420 std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
1421 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
1422 AllocateDestinationBuffer(
1423 ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
1424 dst_device, dst_local_device, transfer_stream,
1425 /*is_uninitialized_create=*/false, client_));
1426
1427 TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
1428
1429 WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream);
1430
1431 ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
1432 CHECK(dst_device_buffer.ok());
1433 ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);
1434
1435 // Copy the leaf buffers.
1436 StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
1437 [&]() -> StatusOr<std::shared_ptr<BufferSequencingEvent>> {
1438 for (const auto& leaf : src_buffer.buffers().leaves()) {
1439 const ShapeIndex& index = leaf.first;
1440 const se::DeviceMemoryBase& input_buffer = leaf.second;
1441 const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
1442 TF_RET_CHECK(input_buffer.size() == output_buffer.size())
1443 << "input: " << input_buffer.size()
1444 << " output: " << output_buffer.size();
1445 if (input_buffer.size() != 0) {
1446 TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
1447 transfer_stream, dst_local_device->compute_stream(), input_buffer,
1448 output_buffer));
1449 }
1450 }
1451 std::shared_ptr<BufferSequencingEvent> event =
1452 dst_device_buffer->definition_events()[0];
1453 TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
1454 transfer_local_device, std::move(dst_device_buffer), event,
1455 transfer_stream));
1456 return event;
1457 }();
1458 if (!copy_event_or.ok()) {
1459 StallStreamOnError(transfer_local_device, transfer_stream);
1460 if (transfer_local_device == dst_local_device) {
1461 // Some copies may have been enqueued before the error was returned, and
1462 // StallStreamOnError only makes sure the destination device is ok, so
1463 // make sure that the src buffer remains valid until after any transfers
1464 // have completed.
1465 device_->local_device_state()->ThenRelease(transfer_stream,
1466 std::move(src_device_buffer));
1467 }
1468 return copy_event_or.status();
1469 }
1470
1471 return std::pair<std::unique_ptr<PjRtBuffer>,
1472 std::shared_ptr<BufferSequencingEvent>>(
1473 std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
1474 std::move(copy_event_or).value());
1475 }
1476
1477 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
1478 PjRtDevice* dst_device) {
1479 tensorflow::profiler::TraceMe traceme(
1480 "PjRtStreamExecutorBuffer::CopyToDevice");
1481 VLOG(1) << "PjRtStreamExecutorBuffer::CopyToDevice";
1482 if (dst_device == device_) {
1483 return InvalidArgument(
1484 "CopyToDevice cannot accept the same source and destination devices");
1485 }
1486
1487 // Copying across PjRtClients involves a copy through the host.
1488 if (dst_device->client() != client_) {
1489 TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteralSync());
1490 // Avoid use-after-free on `literal` due to unsequenced move and use.
1491 Literal* literal_pointer = literal.get();
1492 absl::InlinedVector<int64_t, 4> byte_strides(
1493 literal->shape().dimensions_size());
1494 TF_RETURN_IF_ERROR(
1495 ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides)));
1496 return dst_device->client()->BufferFromHostBuffer(
1497 literal_pointer->untyped_data(),
1498 literal_pointer->shape().element_type(),
1499 literal_pointer->shape().dimensions(), byte_strides,
1500 PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
1501 [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
1502 }
1503
1504 TF_ASSIGN_OR_RETURN(
1505 LocalDeviceState * dst_local_device,
1506 tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
1507 ->GetLocalDeviceState());
1508 LocalDeviceState* transfer_local_device =
1509 client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
1510 : dst_local_device;
1511 CHECK_EQ(dst_local_device->allocation_model(),
1512 transfer_local_device->allocation_model());
1513
1514 se::Stream* transfer_stream =
1515 transfer_local_device->GetDeviceToDeviceStream();
1516
1517 ScopedHold src_device_buffer(this, ScopedHold::kUsage);
1518 {
1519 absl::MutexLock lock(&mu_);
1520 // We can't perform any other action while a donation hold is in progress.
1521 WaitForOutstandingDonationHold();
1522 if (device_buffer_ == nullptr) {
1523 return InvalidArgument(
1524 "CopyToDevice called on deleted or donated buffer");
1525 }
1526 AcquireHoldLocked(&src_device_buffer);
1527 }
1528
1529 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1530 std::shared_ptr<BufferSequencingEvent>>>
1531 buffer_and_event_or = CopyToDeviceHelper(
1532 dst_device, dst_local_device, transfer_local_device, transfer_stream,
1533 src_device_buffer.buffer());
1534 if (!buffer_and_event_or.ok()) {
1535 return buffer_and_event_or.status();
1536 }
1537
1538 auto& buffer_and_event = buffer_and_event_or.ValueOrDie();
1539 std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
1540 std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;
1541
1542   // Passing /*prefer_to_retain_reference=*/true means that, when using the
1543 // ComputeSynchronized allocation model, retain a reference to the
1544 // src_device_buffer until the copy completes. This is a heuristic; the
1545 // alternative is to ensure, before freeing the buffer, that the compute
1546 // stream is synchronized past the transfer, but it seems better to hold onto
1547 // the buffer too long than to stall the compute stream.
1548 RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
1549 transfer_local_device, event, transfer_stream,
1550 /*prefer_to_retain_reference=*/true);
1551
1552 return std::move(buffer);
1553 }
1554
1555 void PjRtStreamExecutorBuffer::CopyToRemoteDevice(
1556 absl::string_view serialized_descriptor, RemoteSendCallback on_done) {
1557 VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDevice";
1558 client_->CopyToRemoteDevice(this, serialized_descriptor, std::move(on_done));
1559 }
1560
1561 void PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered(
1562 absl::Span<const std::pair<std::string, RemoteSendCallback>>
1563 serialized_descriptors_and_callbacks,
1564 const ScatterDetails& scatter_details) {
1565 VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered";
1566 client_->CopyToRemoteDeviceScattered(
1567 this, serialized_descriptors_and_callbacks, scatter_details);
1568 }
1569
1570 PjRtFuture<Status> PjRtStreamExecutorBuffer::GetReadyFuture() {
1571 std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1572 PjRtFuture<Status>::Promise definition_promise;
1573 {
1574 absl::MutexLock lock(&mu_);
1575 if (device_buffer_ == nullptr) {
1576 return PjRtFuture<Status>(InvalidArgument(
1577 "GetReadyFuture() called on deleted or donated buffer"));
1578 }
1579 if (!definition_promise_) {
1580 device_buffer = device_buffer_;
1581 definition_promise_ = PjRtFuture<Status>::CreatePromise();
1582 }
1583 definition_promise = definition_promise_;
1584 }
1585
1586 if (device_buffer) {
1587 LocalDeviceState* local_device_state = device_->local_device_state();
1588 std::unique_ptr<se::Stream> stream;
1589 for (auto& event : device_buffer->definition_events()) {
1590 if (!event->IsComplete()) {
1591 if (stream == nullptr) {
1592 stream = local_device_state->BorrowStreamFromPool();
1593 }
1594 event->WaitForEventOnStream(stream.get());
1595 }
1596 }
1597 if (stream != nullptr) {
1598 auto* stream_ptr = stream.release();
1599 // We already borrowed a stream from the pool so we can safely do the
1600 // callback directly on that stream instead of bouncing through
1601 // local_device_state->ThenExecuteCallback. The direct callback saves
1602 // significant time.
1603 stream_ptr->ThenDoHostCallback(
1604 [definition_promise, stream_ptr, local_device_state]() mutable {
1605 local_device_state->ReturnStreamToPool(
1606 std::unique_ptr<se::Stream>(stream_ptr));
1607 definition_promise.Set(OkStatus());
1608 });
1609 } else {
1610 // All events are already complete.
1611 definition_promise.Set(OkStatus());
1612 }
1613 }
1614
1615 return PjRtFuture<Status>(
1616 std::move(definition_promise),
1617 /*on_block_start=*/
1618 []() {
1619 tensorflow::profiler::TraceMeProducer traceme(
1620 "PjRtStreamExecutorBuffer::Await");
1621 VLOG(1) << "PjRtStreamExecutorBuffer::Await";
1622 return PjRtFutureHelpers::ProfilingKeys(
1623 {/*traceme_context_id=*/traceme.GetContextId()});
1624 },
1625 /*on_block_end=*/
1626 [](PjRtFutureHelpers::ProfilingKeys keys) {
1627 tensorflow::profiler::TraceMeConsumer traceme(
1628 "PjRtStreamExecutorBuffer::Await", keys.traceme_context_id);
1629 });
1630 }
1631
1632 namespace {
1633
1634 // Helper struct for the tuple that is transiently constructed to hold the
1635 // arguments of an execution.
1636 struct TupleHandle {
1637 // The ExecutionInput describing the tuple.
1638 ExecutionInput execution_input;
1639 // A definition event that has been recorded on the host_to_device stream
1640 // after the tuple table transfer.
1641 std::shared_ptr<BufferSequencingEvent> event;
1642 };
1643
1644 Status CheckCompatibleShapes(bool strict_shape_checking,
1645 const Shape& buffer_shape,
1646 const Shape& execution_shape,
1647 const TransferManager& transfer_manager,
1648 int parameter_index) {
1649 // TODO(misard) Support casting of tuple parameters.
1650 if (strict_shape_checking || buffer_shape.IsTuple()) {
1651 if (!ShapeUtil::Equal(buffer_shape, execution_shape)) {
1652 return InvalidArgument(
1653 "Executable expected shape %s for argument %d but got "
1654 "incompatible "
1655 "shape %s",
1656 ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1657 ShapeUtil::HumanStringWithLayout(buffer_shape));
1658 }
1659 } else {
1660 if (transfer_manager.GetByteSizeRequirement(buffer_shape) !=
1661 transfer_manager.GetByteSizeRequirement(execution_shape)) {
1662 return InvalidArgument(
1663 "Executable expected shape %s for argument %d but got "
1664 "incompatible "
1665 "shape %s",
1666 ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1667 ShapeUtil::HumanStringWithLayout(buffer_shape));
1668 }
1669 }
1670 return OkStatus();
1671 }
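// For example, with strict_shape_checking=false, CheckCompatibleShapes above
// accepts an argument of shape f32[2,2] for a parameter declared as f32[4]
// because the two shapes have equal byte-size requirements, whereas strict
// checking (and any tuple-shaped parameter) requires the shapes, including
// layouts, to match exactly.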
1672
1673 // Makes a tuple from the arguments to an execution.
1674 StatusOr<TupleHandle> MakeTupleHelper(
1675 PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
1676 bool strict_shape_checking, const Shape& tupled_parameter_shape,
1677 absl::Span<PjRtBuffer* const> py_buffers,
1678 absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1679 int device_ordinal) {
1680 se::DeviceMemoryAllocator* allocator = client->allocator();
1681 TransferManager* transfer_manager =
1682 client->client()->backend().transfer_manager();
1683
1684 if (tupled_parameter_shape.tuple_shapes_size() != py_buffers.size()) {
1685 return InvalidArgument("Executable expected %lld parameters but got %lld",
1686 tupled_parameter_shape.tuple_shapes_size(),
1687 py_buffers.size());
1688 }
1689 for (int i = 0; i < py_buffers.size(); ++i) {
1690 TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1691 strict_shape_checking, py_buffers[i]->on_device_shape(),
1692 tupled_parameter_shape.tuple_shapes(i), *transfer_manager, i));
1693 }
1694
1695 se::Stream* stream = local_device->host_to_device_stream();
1696 TF_ASSIGN_OR_RETURN(
1697 se::OwningDeviceMemory root_table_memory,
1698 allocator->Allocate(
1699 device_ordinal,
1700 transfer_manager->GetByteSizeRequirement(tupled_parameter_shape)));
1701
1702 if (local_device->allocation_model() ==
1703 LocalDeviceState::kComputeSynchronized) {
1704 stream->ThenWaitFor(local_device->compute_stream());
1705 } else {
1706 DCHECK(transfer_manager->CanBufferBeAccessedNow(
1707 local_device->compute_stream()->parent(), root_table_memory.cref()));
1708 }
1709
1710 ExecutionInput execution_input(tupled_parameter_shape);
1711 ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1712 execution_input.MutableBuffers()->begin();
1713 ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1714 execution_input.MutableBuffers()->end();
1715 // First set the root tuple table which is the first buffer in the ShapeTree.
1716 execution_input.SetBuffer(
1717 input_iterator->first,
1718 MaybeOwningDeviceMemory(std::move(root_table_memory)));
1719 ++input_iterator;
1720 // Then set each sub-tuple in turn from the parameters.
1721 for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
1722 device_buffers) {
1723 device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
1724 allocator);
1725 }
1726 CHECK(input_iterator == iterator_end);
1727
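  // The root tuple index table written below is a small device buffer holding
  // the addresses of the parameter buffers; the executable dereferences it at
  // run time to locate the elements of its single tupled argument.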
1728 TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
1729 stream, execution_input.Buffers()));
1730 StatusOr<EventPool::Handle> event_or =
1731 local_device->event_pool().ThenAllocateAndRecordEvent(stream);
1732 if (!event_or.ok()) {
1733 StallStreamOnError(local_device, stream);
1734 return event_or.status();
1735 }
1736
1737 auto transfer_event = std::make_shared<BufferSequencingEvent>();
1738 transfer_event->SetSequencingEvent(std::move(event_or).value(), stream);
1739 return TupleHandle({std::move(execution_input), std::move(transfer_event)});
1740 }
1741
1742 // Converts a ScopedShapedBuffer returned from an execution into a
1743 // PjRtBuffer.
1744 std::unique_ptr<PjRtBuffer> OutputBufferHelper(
1745 ScopedShapedBuffer* result_buffer,
1746 std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
1747 PjRtDevice* device, LocalDeviceState* local_device,
1748 std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release) {
1749 std::shared_ptr<TrackedDeviceBuffer> out_buffer =
1750 TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
1751 {definition_event});
1752 auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
1753 result_buffer->on_device_shape(), std::move(out_buffer), client, device);
1754 RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
1755 definition_event, local_device->compute_stream(),
1756 /*prefer_to_retain_reference=*/false, &buffers_to_release);
1757 return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
1758 }
1759 } // namespace
1760
1761 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
1762 std::vector<std::unique_ptr<LocalExecutable>> executables,
1763 bool parameter_is_tupled_arguments,
1764 std::shared_ptr<DeviceAssignment> device_assignment,
1765 std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1766 std::vector<PjRtDevice*> addressable_devices,
1767 PjRtStreamExecutorClient* client)
1768 : client_(client),
1769 device_assignment_(std::move(device_assignment)),
1770 parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1771 addressable_device_logical_ids_(
1772 std::move(addressable_device_logical_ids)),
1773 addressable_devices_(std::move(addressable_devices)) {
1774 TransferManager* transfer_manager =
1775 client_->client()->backend().transfer_manager();
1776 executables_.reserve(executables.size());
1777 for (auto& executable : executables) {
1778 const auto& computation_layout =
1779 executable->executable()->module().entry_computation_layout();
1780 std::vector<Shape> parameter_shapes;
1781 parameter_shapes.reserve(computation_layout.parameter_count());
1782 for (int i = 0; i < computation_layout.parameter_count(); ++i) {
1783 parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape(
1784 computation_layout.parameter_shape(i)));
1785 }
1786 executables_.emplace_back(std::move(executable));
1787 on_device_executable_parameter_shapes_.push_back(
1788 std::move(parameter_shapes));
1789 }
1790
1791 int num_partitions;
1792 if (device_assignment_ == nullptr) {
1793 // This must go after `executables_` is initialized.
1794 VLOG(3) << "PjRtStreamExecutorExecutable portable single-core";
1795 num_partitions = 1;
1796 CHECK(addressable_devices_.empty());
1797 } else {
1798 // This must go after `executables_` is initialized.
1799 VLOG(3) << "PjRtStreamExecutorExecutable device_assignment:\n"
1800 << device_assignment_->ToString();
1801 CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
1802 CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
1803 << "Inconsistent local device count.";
1804 num_partitions = device_assignment_->computation_count();
1805 }
1806
1807 // SPMD sharding produces a single executable for multiple partitions.
1808 if (executables_.size() > 1) {
1809 CHECK_EQ(num_partitions, executables_.size())
1810 << "Number of executables " << executables_.size()
1811 << " did not match number of partitions " << num_partitions;
1812 }
1813 }
1814
1815 Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
1816 parameters_that_must_be_donated_.reserve(executables_.size());
1817 for (auto& executable : executables_) {
1818 TF_ASSIGN_OR_RETURN(std::vector<int> parameters_to_donate,
1819 ComputeParametersThatMustBeDonated(
1820 executable->executable()->module(), tuple_inputs));
1821 parameters_that_must_be_donated_.emplace_back(
1822 std::move(parameters_to_donate));
1823 }
1824 return OkStatus();
1825 }
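// A parameter "must be donated" when the module's input/output aliasing lets
// the computation reuse that parameter's memory for an output. Passing such a
// buffer to an execution invalidates the caller's handle once the donation is
// confirmed (see ConfirmDonation above).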
1826
1827 absl::string_view PjRtStreamExecutorExecutable::name() const {
1828 Executable* executable = executables_[0]->executable();
1829 if (executable->has_module()) {
1830 return executable->module().name();
1831 } else {
1832 return "<unknown executable>";
1833 }
1834 }
1835
1836 absl::Span<int const> PjRtStreamExecutorExecutable::ParametersThatMustBeDonated(
1837 int executable_idx) const {
1838 return parameters_that_must_be_donated_[executable_idx];
1839 }
1840
1841 StatusOr<std::vector<ExecutionInput>>
1842 PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
1843 int device_ordinal, const ExecuteOptions& options,
1844 absl::Span<const Shape> executable_parameter_shapes,
1845 absl::Span<PjRtBuffer* const> argument_handles,
1846 absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1847 absl::flat_hash_set<BufferSequencingEvent*>& events) const {
1848 std::vector<ExecutionInput> execution_inputs;
1849 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1850 TransferManager* transfer_manager =
1851 client_->client()->backend().transfer_manager();
1852 // Lift tuple_handle outside the conditional so that the event it returns is
1853 // not destroyed until after the loop below that waits on events.
1854 std::optional<TupleHandle> tuple_handle;
1855 if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1856 TF_ASSIGN_OR_RETURN(
1857 tuple_handle,
1858 MakeTupleHelper(client_, device_state, options.strict_shape_checking,
1859 executable_parameter_shapes[0], argument_handles,
1860 device_buffers, device_ordinal));
1861 events.insert(tuple_handle->event.get());
1862 execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
1863 } else {
1864 if (argument_handles.size() != executable_parameter_shapes.size()) {
1865 return InvalidArgument("Executable expected %lld arguments but got %lld",
1866 executable_parameter_shapes.size(),
1867 argument_handles.size());
1868 }
1869 execution_inputs.reserve(argument_handles.size());
1870 for (int i = 0; i < argument_handles.size(); ++i) {
1871 PjRtBuffer* handle = argument_handles[i];
1872
1873 // Make an ExecutionInput from the device buffer.
1874 TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1875 options.strict_shape_checking, handle->on_device_shape(),
1876 executable_parameter_shapes[i], *transfer_manager, i));
1877 execution_inputs.emplace_back(executable_parameter_shapes[i]);
1878 ExecutionInput& execution_input = execution_inputs.back();
1879 ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1880 execution_input.MutableBuffers()->begin();
1881 ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1882 execution_input.MutableBuffers()->end();
1883 device_buffers[i].AddToInput(&input_iterator, iterator_end,
1884 &execution_input, client_->allocator());
1885 CHECK(input_iterator == iterator_end);
1886 }
1887 }
1888
1889 for (BufferSequencingEvent* event : events) {
1890 event->WaitForEventOnStream(device_state->compute_stream());
1891 }
1892
1893 return execution_inputs;
1894 }
1895
1896 // Enqueues a computation onto the compute stream. Each buffer returned in
1897 // device_buffers has a usage hold added that must be dropped on error or
1898 // converted on success.
1899 StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
1900 absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1901 int executable_idx, const RunId& run_id, const ExecuteOptions& options,
1902 PjRtDevice* device,
1903 std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
1904 std::shared_ptr<DeviceAssignment> device_assignment,
1905 std::vector<std::function<void()>>& compute_callbacks) const {
1906 int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1907 ->local_device_state()
1908 ->device_ordinal();
1909 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1910 tensorflow::profiler::TraceMeConsumer activity(
1911 "PjRtStreamExecutorExecutable::EnqueueExecution",
1912 tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
1913 VLOG(3) << "Replica " << replica << ", partition " << partition
1914 << " mapped to device ordinal for execution: " << device_ordinal;
1915
1916 absl::flat_hash_set<BufferSequencingEvent*> events;
1917 device_buffers->reserve(argument_handles.size());
1918 absl::Span<int const> donated_params =
1919 ParametersThatMustBeDonated(executable_idx);
1920 auto donate_it = donated_params.begin();
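  // The loop below advances donate_it in lockstep with the argument index; it
  // relies on donated_params being ordered by increasing parameter number.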
1921 for (int i = 0; i < argument_handles.size(); ++i) {
1922 auto* handle =
1923 tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
1924 if (handle->device() != device) {
1925 return InvalidArgument(
1926 "Buffer passed to Execute() as argument %d to replica %d is on "
1927 "device %s, but replica is assigned to device %s.",
1928 i, replica, handle->device()->DebugString(), device->DebugString());
1929 }
1930 bool must_donate = donate_it != donated_params.end() && *donate_it == i;
1931 if (must_donate) {
1932 ++donate_it;
1933 }
1934 device_buffers->emplace_back(handle->GetBufferWithHold(
1935 must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
1936 : PjRtStreamExecutorBuffer::ScopedHold::kUsage));
1937 PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
1938 device_buffers->back();
1939 if (!device_buffer.ok()) {
1940 return InvalidArgument(
1941 "Invalid buffer passed to Execute() as argument %d to replica %d: "
1942 "%s",
1943 i, replica, device_buffer.status().ToString());
1944 }
1945 // If we are trying to donate the buffer wait on the usage events as well
1946 // as the definition events to ensure that all reads have been completed
1947 // before the buffer is mutated. Usage holds are excluded during a donation
1948 // hold so we know that the set of usage events won't be modified while we
1949 // are enqueueing.
1950 GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
1951 &events);
1952 }
1953
1954 if (options.arguments_are_tupled) {
1955 if (!parameter_is_tupled_arguments_) {
1956 return InvalidArgument(
1957 "Arguments may only be supplied as a tuple when the executable was "
1958 "compiled with a single tupled parameter");
1959 }
1960 if (argument_handles.size() != 1) {
1961 return InvalidArgument(
1962 "Option arguments_are_tupled was true but %d buffers were passed to "
1963 "execution",
1964 argument_handles.size());
1965 }
1966 }
1967
1968 TF_ASSIGN_OR_RETURN(
1969 std::vector<ExecutionInput> execution_inputs,
1970 MakeExecutionInputsAndWaitForEvents(
1971 device_ordinal, options,
1972 on_device_executable_parameter_shapes_[executable_idx],
1973 argument_handles, *device_buffers, events));
1974
1975 ExecutableRunOptions run_options;
1976 run_options.set_stream(device_state->compute_stream());
1977 run_options.set_host_to_device_stream(device_state->host_to_device_stream());
1978 run_options.set_allocator(client_->allocator());
1979 run_options.set_intra_op_thread_pool(
1980 client_->client()->backend().eigen_intra_op_thread_pool_device());
1981 run_options.set_device_assignment(device_assignment.get());
1982 run_options.set_run_id(run_id);
1983 run_options.set_rng_seed(device_state->GetNewPrngSeed());
1984 run_options.set_gpu_executable_run_options(client_->gpu_run_options());
1985 run_options.set_launch_id(options.launch_id);
1986 if (run_options.launch_id() != 0) {
1987 VLOG(3) << "launch id for " << name() << ": " << run_options.launch_id();
1988 }
1989
1990 // The choice of where we wait is arbitrary; the reason for the wait is
1991 // pacing to avoid problems such as memory fragmentation and running ahead
1992 // too far, not for correctness. Placing it before the executable launch
1993 // allows the inputs for the next executable to be fetched even if the
1994 // launch is delayed.
1995 std::shared_ptr<Semaphore::ScopedReservation> compute_reservation;
1996 {
1997 tensorflow::profiler::TraceMe traceme("ComputeSemaphoreAcquire");
1998 compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
1999 device_state->compute_semaphore().ScopedAcquire(1));
2000 }
2001
2002 StatusOr<ExecutionOutput> result_buffer_or_status =
2003 executables_[executable_idx]->RunAsync(std::move(execution_inputs),
2004 run_options);
2005
2006 VLOG(1) << "Replica " << replica << " partition " << partition
2007 << " completed; ok=" << result_buffer_or_status.ok();
2008
2009 if (!result_buffer_or_status.ok()) {
2010 return result_buffer_or_status.status();
2011 }
2012
2013 if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
2014 ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
2015 // If we used a transient tuple for the arguments we donated its root table
2016 // buffer. In that case, and/or if we donated any input buffers that were
2017 // not aliased, the donated buffers are going to be passed back to us via
2018 // the execution output. We need to ensure they aren't freed until after
2019 // execution completes. (Currently XLA does not support aliasing tuple
2020 // tables, so if any donated parameter is a tuple there will be donated but
2021 // unaliased buffers.)
2022 std::vector<se::OwningDeviceMemory> donated_memory =
2023 execution_output.ConsumeToBeReleased();
2024 absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
2025 donated_ptrs.reserve(donated_memory.size());
2026 for (se::OwningDeviceMemory& owning : donated_memory) {
2027 // Release the owning memory so we can pass it to the closure.
2028 donated_ptrs.push_back(owning.Release());
2029 }
2030 compute_callbacks.push_back(
2031 [references{std::make_tuple(executables_[executable_idx],
2032 compute_reservation, device_assignment)},
2033 donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
2034 device_ordinal]() {
2035 for (const auto& ptr : donated_ptrs) {
2036 TF_CHECK_OK(allocator->Deallocate(device_ordinal, ptr));
2037 }
2038 });
2039 } else {
2040 // Any donated memory returned by the ExecutionOutput can be immediately
2041 // freed.
2042 compute_callbacks.push_back(
2043 [to_release{std::make_tuple(executables_[executable_idx],
2044 compute_reservation,
2045 device_assignment)}]() {});
2046 }
2047
2048 return std::move(result_buffer_or_status).value().ConsumeResult();
2049 }
2050
2051 std::vector<std::unique_ptr<PjRtBuffer>>
2052 PjRtStreamExecutorExecutable::MakeOutputBuffers(
2053 int device_ordinal, const ExecuteOptions& options,
2054 ScopedShapedBuffer result_buffer,
2055 std::shared_ptr<BufferSequencingEvent> definition_event, PjRtDevice* device,
2056 std::vector<std::function<void()>>& compute_callbacks,
2057 std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
2058 const {
2059 tensorflow::profiler::TraceMe traceme("MakeOutputBuffers");
2060 std::vector<std::unique_ptr<PjRtBuffer>> outputs;
2061 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2062 if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
2063 int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
2064 outputs.reserve(tuple_count);
2065 // Take ownership of each of the output values, leaving only the root table
2066 // in result_buffer.
2067 for (int i = 0; i < tuple_count; ++i) {
2068 ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i});
2069 outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event,
2070 client_, device, device_state,
2071 buffers_to_release));
2072 }
2073 if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
2074 // Don't release the root buffer until after execution completes.
2075 ShapedBuffer root_buffer_holder = result_buffer.release();
2076 se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer();
2077 compute_callbacks.push_back(
2078 [root_buffer, allocator{client_->allocator()}, device_ordinal]() {
2079 TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer));
2080 });
2081 }
2082 } else {
2083 outputs.push_back(OutputBufferHelper(&result_buffer, definition_event,
2084 client_, device, device_state,
2085 buffers_to_release));
2086 }
2087 return outputs;
2088 }
2089
2090 StatusOr<PjRtLoadedExecutable::Result>
2091 PjRtStreamExecutorExecutable::ExecuteHelper(
2092 absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
2093 const RunId& run_id, const ExecuteOptions& options, bool fill_future,
2094 PjRtDevice* device) const {
2095 const uint64_t start_time_usecs = tensorflow::Env::Default()->NowMicros();
2096 std::shared_ptr<DeviceAssignment> device_assignment;
2097 if (device == nullptr) {
2098 CHECK(device_assignment_ != nullptr);
2099 const int device_id = (*device_assignment_)(replica, partition);
2100 TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
2101 device_assignment = device_assignment_;
2102 } else {
2103 CHECK(device_assignment_ == nullptr);
2104 CHECK_EQ(replica, 0);
2105 CHECK_EQ(partition, 0);
2106 CHECK(addressable_devices_.empty());
2107 device_assignment = std::make_shared<DeviceAssignment>(1, 1);
2108 (*device_assignment)(0, 0) = device->id();
2109 }
2110
2111 CHECK_EQ(device->process_index(), client_->process_index());
2112 int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
2113 ->local_device_state()
2114 ->device_ordinal();
2115 tensorflow::profiler::TraceMe traceme(
2116 "PjRtStreamExecutorExecutable::ExecuteHelper");
2117 VLOG(1) << "Replica " << replica << ", partition " << partition
2118 << " mapped to device ordinal for execution: " << device_ordinal;
2119
2120 // SPMD sharding produces a single executable for multiple partitions.
2121 int executable_idx = executables_.size() > 1 ? partition : 0;
2122
2123 std::vector<std::function<void()>> compute_callbacks;
2124 std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
2125 device_buffers.reserve(argument_handles.size());
2126 StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
2127 argument_handles, replica, partition, executable_idx, run_id, options,
2128 device, &device_buffers, std::move(device_assignment), compute_callbacks);
2129
2130 if (!result_buffer_or_status.ok()) {
2131 LOG(ERROR) << "Execution of replica " << replica
2132 << " failed: " << result_buffer_or_status.status();
2133 return result_buffer_or_status.status();
2134 }
2135 ScopedShapedBuffer result_buffer = std::move(result_buffer_or_status).value();
2136
2137 LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2138 se::Stream* stream = device_state->compute_stream();
2139 StatusOr<EventPool::Handle> event_or =
2140 device_state->event_pool().ThenAllocateAndRecordEvent(stream);
2141 if (!event_or.ok()) {
2142 StallStreamOnError(device_state, stream);
2143 for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2144 if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
2145 // Even though there was an error we need to call ConfirmDonation, which
2146 // renders b invalid, since the computation has been enqueued and b has
2147 // been donated.
2148 b.ConfirmDonation();
2149 }
2150 }
2151 return event_or.status();
2152 }
2153 auto definition_event = std::make_shared<BufferSequencingEvent>();
2154 definition_event->SetSequencingEvent(std::move(event_or).value(), stream);
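  // All output buffers created below share this single definition event, which
  // was just recorded on the compute stream after the execution; their contents
  // become valid once the stream reaches that event.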
2155 std::vector<std::shared_ptr<TrackedDeviceBuffer>> buffers_to_release;
2156 std::vector<std::unique_ptr<PjRtBuffer>> outputs = MakeOutputBuffers(
2157 device_ordinal, options, std::move(result_buffer), definition_event,
2158 device, compute_callbacks, buffers_to_release);
2159
2160 for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2161 // prefer_to_retain_reference=false because when using the
2162 // ComputeSynchronized allocation model we don't need to retain a reference
2163 // to the device_buffer during execution because by definition the compute
2164 // stream is synchronized past the execution.
2165 if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
2166 RecordUsage(std::move(b), device_state, device_state, definition_event,
2167 stream,
2168 /*prefer_to_retain_reference=*/false, &buffers_to_release);
2169 } else {
2170 CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
2171 b.ConfirmDonation();
2172 }
2173 }
2174
2175 std::optional<PjRtFuture<Status>> future;
2176 if (fill_future) {
2177 auto promise = PjRtFuture<Status>::CreatePromise();
2178 future = PjRtFuture<Status>(promise);
2179 compute_callbacks.push_back(
2180 [promise = std::move(promise)]() mutable { promise.Set(OkStatus()); });
2181 }
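  // The callbacks accumulated above (including the promise fulfillment for the
  // returned future, when requested) only run once the compute stream reaches
  // this point, i.e. after the enqueued execution has completed on the device.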
2182 device_state->ThenExecuteCallback(
2183 stream, [callbacks{std::move(compute_callbacks)},
2184 buffers_to_release{std::move(buffers_to_release)}]() {
2185 for (auto& fn : callbacks) {
2186 fn();
2187 }
2188 });
2189 ReportExecutableEnqueueTime(tensorflow::Env::Default()->NowMicros() -
2190 start_time_usecs);
2191 return Result({/*future=*/std::move(future), /*buffers=*/std::move(outputs)});
2192 }
2193
2194 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
2195 PjRtStreamExecutorExecutable::Execute(
2196 absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
2197 const ExecuteOptions& options,
2198 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
2199 if (device_assignment_ == nullptr) {
2200 return InvalidArgument("Execute expects a non-null device_assignment");
2201 }
2202
2203 RunId run_id;
2204 tensorflow::profiler::TraceMeProducer activity(
2205 "PjRtStreamExecutorExecutable::Execute",
2206 tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
2207
2208 const int num_addressable_devices = addressable_devices_.size();
2209
2210 if (argument_handles.size() != num_addressable_devices) {
2211 return InvalidArgument(
2212 "Attempted to execute with %d argument lists when local device "
2213 "count is %d (total replica count: %d, partition count: %d)",
2214 argument_handles.size(), num_addressable_devices, num_replicas(),
2215 num_partitions());
2216 }
2217
2218 VLOG(1) << "Executing computation " << name()
2219 << "; num_replicas=" << num_replicas()
2220 << " num_partitions=" << num_partitions()
2221 << " num_addressable_devices=" << num_addressable_devices;
2222 std::vector<StatusOr<Result>> results(num_addressable_devices);
2223 if (num_addressable_devices == 1) {
2224   // Fast-path if there is only one device: run the computation on the
2225   // current thread.
2226 const int replica = addressable_device_logical_ids_[0].replica;
2227 const int partition = addressable_device_logical_ids_[0].partition;
2228 results[0] = ExecuteHelper(argument_handles[0], replica, partition, run_id,
2229 options, returned_futures.has_value());
2230 } else {
2231 absl::Mutex mu;
2232 int running = num_addressable_devices;
2233 int failed = 0;
2234 Status first_failure_status;
2235
2236 for (int i = 0; i < num_addressable_devices; ++i) {
2237 const int replica = addressable_device_logical_ids_[i].replica;
2238 const int partition = addressable_device_logical_ids_[i].partition;
2239 PjRtDevice* device = addressable_devices_[i];
2240 const LocalDeviceState& device_state =
2241 *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
2242 ->local_device_state();
2243 device_state.execute_thread()->Schedule([&, replica, partition, i] {
2244 results[i] =
2245 ExecuteHelper(argument_handles[i], replica, partition, run_id,
2246 options, returned_futures.has_value());
2247
2248 absl::MutexLock lock(&mu);
2249 --running;
2250 if (!results[i].ok()) {
2251 if (failed == 0) {
2252 first_failure_status = results[i].status();
2253 }
2254 ++failed;
2255 }
2256 });
2257 }
2258
2259 auto done_running_or_failed = [&]() {
2260 mu.AssertHeld();
2261 return running == 0 || failed > 0;
2262 };
2263 absl::MutexLock lock(&mu);
2264 mu.Await(absl::Condition(&done_running_or_failed));
2265 if (failed > 0) {
2266 auto done_running = [&]() {
2267 mu.AssertHeld();
2268 return running == 0;
2269 };
2270 // If execution does not terminate within a reasonable amount of time,
2271 // we may be stuck at a cross-replica barrier on-device. Terminate the
2272 // process since that's the only way we can escape this situation at the
2273 // moment (b/130629719).
2274 if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
2275 absl::Seconds(10))) {
2276 LOG(FATAL)
2277 << "Replicated computation launch failed, but not all replicas "
2278 "terminated. Aborting process to work around deadlock. "
2279 "Failure message (there may have been multiple failures, see "
2280 "the error log for all failures): \n\n"
2281 << first_failure_status.error_message();
2282 }
2283 }
2284 }
2285 VLOG(1) << "Replicated execution complete.";
2286
2287 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
2288 num_addressable_devices);
2289 if (returned_futures.has_value()) {
2290 returned_futures->reserve(num_addressable_devices);
2291 }
2292 for (int i = 0; i < num_addressable_devices; ++i) {
2293 const int replica = addressable_device_logical_ids_[i].replica;
2294 const int partition = addressable_device_logical_ids_[i].partition;
2295 auto& statusor = results[i];
2296 if (!statusor.ok()) {
2297 if (returned_futures.has_value()) {
2298 returned_futures->clear();
2299 }
2300 if (num_addressable_devices == 1) {
2301 return statusor.status();
2302 } else {
2303 return AppendStatus(
2304 statusor.status(),
2305 absl::StrFormat("while running replica %d and partition %d of a "
2306 "replicated computation (other "
2307 "replicas may have failed as well).",
2308 replica, partition));
2309 }
2310 }
2311 wrapped_results[i] = std::move(statusor->buffers);
2312 if (returned_futures.has_value()) {
2313 returned_futures->push_back(*std::move(statusor->future));
2314 }
2315 }
2316 return wrapped_results;
2317 }
2318
2319 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2320 PjRtStreamExecutorExecutable::ExecuteSharded(
2321 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2322 const ExecuteOptions& options,
2323 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
2324 if (device_assignment_ == nullptr) {
2325     return InvalidArgument(
2326         "ExecuteSharded expects a non-null device_assignment");
2326 }
2327 for (int i = 0; i < addressable_devices_.size(); ++i) {
2328 if (addressable_devices_[i] == device) {
2329       VLOG(1) << "ExecuteSharded executes computation " << name()
2330 << " on assigned replica/partition on device "
2331 << device->DebugString();
2332 TF_ASSIGN_OR_RETURN(
2333 auto result,
2334 ExecuteHelper(argument_handles,
2335 addressable_device_logical_ids_[i].replica,
2336 addressable_device_logical_ids_[i].partition, RunId(),
2337 options, fill_future));
2338 returned_future = std::move(result.future);
2339 return std::move(result.buffers);
2340 }
2341 }
2342 return InvalidArgument(
2343       "ExecuteSharded attempted to execute on device id %d which is not "
2344 "addressable by this client",
2345 device->id());
2346 }
2347
2348 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2349 PjRtStreamExecutorExecutable::ExecutePortable(
2350 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2351 const ExecuteOptions& options,
2352 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
2353 if (device_assignment_ != nullptr) {
2354     return InvalidArgument(
2355         "ExecutePortable called on a non-portable executable");
2355 }
2356 if (num_replicas() != 1 || num_partitions() != 1) {
2357 return InvalidArgument(
2358         "ExecutePortable expects a single-core executable but got one "
2359         "with %d replicas and %d partitions",
2360 num_replicas(), num_partitions());
2361 }
2362 if (device == nullptr) {
2363 return InvalidArgument("ExecutePortable expects a device to be specified");
2364 }
2365 VLOG(1) << "ExecutePortable executes single-core portable executable "
2366 << name();
2367 TF_ASSIGN_OR_RETURN(auto result, ExecuteHelper(argument_handles,
2368 /*replica=*/0,
2369 /*partition=*/0, RunId(),
2370 options, fill_future, device));
2371 returned_future = std::move(result.future);
2372 return std::move(result.buffers);
2373 }
2374
2375 StatusOr<std::vector<std::shared_ptr<HloModule>>>
2376 PjRtStreamExecutorExecutable::GetHloModules() const {
2377 std::vector<std::shared_ptr<HloModule>> modules;
2378 modules.reserve(executables().size());
2379 for (const auto& local_exec : executables()) {
2380 if (!local_exec->executable()->has_module()) {
2381 return InvalidArgument("Executable does not have HLO modules.");
2382 }
2383 modules.push_back(local_exec->executable()->shared_module());
2384 }
2385 return std::move(modules);
2386 }
2387
2388 StatusOr<PjRtStreamExecutorClient::ExecutableExtras>
2389 PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
2390 ExecutableExtras extras;
2391 std::shared_ptr<DeviceAssignment>& device_assignment =
2392 extras.device_assignment;
2393 std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
2394 addressable_device_logical_ids = extras.addressable_device_logical_ids;
2395 std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;
2396
2397 ExecutableBuildOptions& build_options = options->executable_build_options;
2398 if (!build_options.compile_thread_pool()) {
2399 build_options.set_compile_thread_pool(thread_pool());
2400 }
2401 if (!build_options.device_allocator()) {
2402 build_options.set_device_allocator(allocator());
2403 }
2404
2405 int num_replicas;
2406 int num_partitions;
2407 TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
2408 options->compile_portable_executable, &options->executable_build_options,
2409 [this](int num_replicas, int num_partitions) {
2410 return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
2411 },
2412 &num_replicas, &num_partitions, &device_assignment));
2413
2414 // Find devices that are addressable by this client/task.
2415 if (device_assignment != nullptr) {
2416 addressable_device_logical_ids.reserve(num_replicas * num_partitions);
2417 addressable_devices.reserve(num_replicas * num_partitions);
2418 for (int replica = 0; replica < num_replicas; ++replica) {
2419 for (int partition = 0; partition < num_partitions; ++partition) {
2420 int device_id = (*device_assignment)(replica, partition);
2421 TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
2422 if (device->process_index() != process_index()) {
2423 VLOG(3) << "Non-local device: " << device_id;
2424 continue;
2425 }
2426         PjRtLoadedExecutable::LogicalDeviceIds logical_device_ids;
2427         logical_device_ids.replica = replica;
2428         logical_device_ids.partition = partition;
2429         addressable_device_logical_ids.push_back(
2430             std::move(logical_device_ids));
2430 addressable_devices.push_back(device);
2431 }
2432 }
2433 if (addressable_devices.empty()) {
2434 return InvalidArgument(
2435 "Device assignment (%s) does not have any local devices.",
2436 device_assignment->ToString());
2437 }
2438
2439 if (build_options.device_ordinal() < 0) {
2440 build_options.set_device_ordinal(
2441 addressable_devices.front()->local_hardware_id());
2442 }
2443 }
2444 return extras;
2445 }
2446
2447 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
2448 PjRtStreamExecutorClient::Compile(const XlaComputation& computation,
2449 CompileOptions options) {
2450 tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
2451 VLOG(1) << "PjRtStreamExecutorClient::Compile";
2452
2453 TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options));
2454 std::shared_ptr<DeviceAssignment>& device_assignment =
2455 extras.device_assignment;
2456 std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
2457 addressable_device_logical_ids = extras.addressable_device_logical_ids;
2458 std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;
2459
2460 std::vector<const Shape*> argument_layout_pointers;
2461 TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
2462 computation,
2463 [local_client = client()](Shape shape) {
2464 return local_client->backend()
2465 .transfer_manager()
2466 ->ChooseCompactLayoutForShape(shape);
2467 },
2468 options.argument_layouts, &options.executable_build_options,
2469 &argument_layout_pointers));
2470
2471 TF_ASSIGN_OR_RETURN(
2472 std::vector<std::unique_ptr<LocalExecutable>> local_executables,
2473 client()->Compile(computation, argument_layout_pointers,
2474 options.executable_build_options));
2475
2476 auto executable = std::make_unique<PjRtStreamExecutorExecutable>(
2477 std::move(local_executables), options.parameter_is_tupled_arguments,
2478 std::move(device_assignment), std::move(addressable_device_logical_ids),
2479 std::move(addressable_devices), this);
2480 TF_RETURN_IF_ERROR(
2481 executable->SetUpDonation(options.parameter_is_tupled_arguments));
2482 return std::unique_ptr<PjRtLoadedExecutable>(std::move(executable));
2483 }
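// Minimal end-to-end sketch of the Compile above plus Execute (illustrative
// only; `client`, `computation`, and `args` are hypothetical): compile a
// computation and run it on all addressable devices, requesting per-device
// completion futures.
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
//                       client->Compile(computation, CompileOptions()));
//   std::optional<std::vector<PjRtFuture<Status>>> futures(std::in_place);
//   TF_ASSIGN_OR_RETURN(auto outputs,
//                       executable->Execute(args, ExecuteOptions(), futures));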
2484
2485 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
2486 PjRtStreamExecutorClient::Compile(mlir::ModuleOp module,
2487 CompileOptions options) {
2488 XlaComputation xla_computation;
2489 TF_RETURN_IF_ERROR(MlirToXlaComputation(
2490 module, xla_computation,
2491 /*use_tuple_args=*/options.parameter_is_tupled_arguments,
2492 /*return_tuple=*/false));
2493 return Compile(xla_computation, options);
2494 }
2495
2496 } // namespace xla
2497