/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Asynchronous execution:
// -----------------------
//
// Computations and host-to-device transfers do not need to block the host
// waiting for the operation to complete but instead return control to the host
// immediately. This allows client logic to overlap with device-side
// computation.
//
// For a good user experience, we must be careful only to enqueue operations
// that are unlikely to fail; as a rule error checking must be done eagerly
// before returning control to the client.
//
// The degree to which the client can enqueue operations ahead of the device
// is limited by a semaphore. There are two modes: asynchronous, where we
// allow the client to enqueue up to 32 executions ahead of the device, and
// synchronous, where we limit the client to having one enqueued operation at
// a time. The value of 32 is arbitrary.
//
// Even in asynchronous mode, it is important that we do not permit
// unbounded queue-ahead. Firstly, it is problematic when the user does
// something like the following in Python:
// %timeit run_computation()
// To the timeit logic, run_computation() appears to be extremely cheap since
// it is deferring all of its real work and not blocking, and so %timeit will
// run it many (e.g., 10000) times to get better timing resolution, even though
// in reality it may be expensive. Secondly, on CPU the allocator is
// synchronized with the head of the compute stream, and we allocate buffers
// for all of the enqueued programs without any reuse (unlike GPU). This means
// that the memory usage is proportional to the queue size.
//
// Multi-stream execution:
// -----------------------
//
// We use a multistream execution design, where different Streams are used for
// host-to-device transfers, device-to-host transfers, and compute. This allows
// us to overlap transfers on and off the device with computation.
//
// Synchronization between streams occurs via BufferSequencingEvents that
// describe when the contents of a logical buffer are known to be valid on
// a particular stream, and when a buffer's uses have all completed.
//
// Synchronous vs asynchronous deallocation:
// -----------------------------------------
//
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
// different allocation semantics on CPU, GPU, and TPU.

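// The sketch below is purely illustrative and is not compiled as part of this
// file; the names used (Semaphore, EnqueueExecution, device_queue) are
// hypothetical stand-ins for the semaphore-based throttling described above.
//
//   // Acquire one unit before enqueueing work and release it from a callback
//   // that runs once the device finishes the enqueued execution. With a
//   // capacity of 32 the client may run ahead of the device by at most 32
//   // executions; with a capacity of 1 execution is effectively synchronous.
//   Semaphore run_ahead(/*capacity=*/32);
//   void EnqueueExecution(Computation computation) {
//     run_ahead.Acquire(1);
//     device_queue.Enqueue(computation,
//                          /*on_done=*/[&]() { run_ahead.Release(1); });
//   }
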
65 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
66 
67 #include <algorithm>
68 #include <cstddef>
69 #include <cstdlib>
70 #include <functional>
71 #include <memory>
72 #include <optional>
73 #include <string>
74 #include <utility>
75 #include <vector>
76 
77 #include "absl/algorithm/container.h"
78 #include "absl/base/casts.h"
79 #include "absl/container/flat_hash_set.h"
80 #include "absl/container/inlined_vector.h"
81 #include "absl/strings/str_format.h"
82 #include "absl/synchronization/mutex.h"
83 #include "absl/time/time.h"
84 #include "absl/types/span.h"
85 #include "tensorflow/compiler/xla/client/local_client.h"
86 #include "tensorflow/compiler/xla/client/xla_computation.h"
87 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
88 #include "tensorflow/compiler/xla/executable_run_options.h"
89 #include "tensorflow/compiler/xla/layout.h"
90 #include "tensorflow/compiler/xla/literal.h"
91 #include "tensorflow/compiler/xla/literal_util.h"
92 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
93 #include "tensorflow/compiler/xla/pjrt/event_pool.h"
94 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
95 #include "tensorflow/compiler/xla/pjrt/metrics.h"
96 #include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
97 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
98 #include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
99 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
100 #include "tensorflow/compiler/xla/pjrt/utils.h"
101 #include "tensorflow/compiler/xla/service/computation_layout.h"
102 #include "tensorflow/compiler/xla/service/executable.h"
103 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
104 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
105 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
106 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
107 #include "tensorflow/compiler/xla/service/transfer_manager.h"
108 #include "tensorflow/compiler/xla/shape.h"
109 #include "tensorflow/compiler/xla/shape_util.h"
110 #include "tensorflow/compiler/xla/util.h"
111 #include "tensorflow/compiler/xla/xla_data.pb.h"
112 #include "tensorflow/core/platform/cpu_info.h"
113 #include "tensorflow/core/platform/env.h"
114 #include "tensorflow/core/platform/errors.h"
115 #include "tensorflow/core/platform/fingerprint.h"
116 #include "tensorflow/core/platform/mem.h"
117 #include "tensorflow/core/platform/status.h"
118 #include "tensorflow/core/platform/statusor.h"
119 #include "tensorflow/core/profiler/lib/connected_traceme.h"
120 #include "tensorflow/core/profiler/lib/traceme.h"
121 #include "tensorflow/core/profiler/lib/traceme_encode.h"
122 #include "tensorflow/stream_executor/device_memory.h"
123 #include "tensorflow/stream_executor/device_memory_allocator.h"
124 #include "tensorflow/stream_executor/event.h"
125 #include "tensorflow/stream_executor/host/host_platform_id.h"
126 #include "tensorflow/stream_executor/lib/statusor.h"
127 #include "tensorflow/stream_executor/stream.h"
128 
129 namespace xla {
130 
PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
  return client_->platform_id();
}
absl::string_view PjRtStreamExecutorDevice::platform_name() const {
  return client_->platform_name();
}

StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
    const {
  if (local_device_state_) {
    return local_device_state_.get();
  }
  return InvalidArgument("Device %s is not a local device.", DebugString());
}

absl::string_view PjRtStreamExecutorDevice::DebugString() const {
  return debug_string_;
}

absl::string_view PjRtStreamExecutorDevice::ToString() const {
  return to_string_;
}

StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
    absl::Span<const std::vector<PjRtDevice*>> devices) {
  if (devices.empty()) {
    return InvalidArgument(
        "Device assignment passed to Compile() must be non-empty.");
  }
  if (devices[0].empty()) {
    return InvalidArgument(
        "Device assignment passed to Compile() must have a nonzero number of "
        "partitions per replica; replica 0 had 0 partitions.");
  }
  DeviceAssignment xla_assignment(devices.size(), devices[0].size());
  for (int replica = 0; replica < devices.size(); ++replica) {
    if (devices[replica].size() != devices[0].size()) {
      return InvalidArgument(
          "Device assignment passed to Compile() has different numbers of "
          "partitions between replicas; %d partitions for replica %d versus %d "
          "partitions for replica 0.",
          devices[replica].size(), replica, devices[0].size());
    }
    for (int partition = 0; partition < devices[replica].size(); ++partition) {
      if (devices[0][0]->client()->platform_id() !=
          devices[replica][partition]->client()->platform_id()) {
        return InvalidArgument(
            "Device assignment passed to Compile() must have devices of a "
            "single kind, got %s for replica 0 partition 0 and %s for replica "
            "%d partition %d.",
            devices[0][0]->client()->platform_name(),
            devices[replica][partition]->client()->platform_name(), replica,
            partition);
      }
      xla_assignment(replica, partition) = devices[replica][partition]->id();
    }
  }
  return xla_assignment;
}

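// Illustrative usage sketch (comment only, not compiled here): building the
// nested device list for 2 replicas x 1 partition before compiling. The
// variables `client` and `devices` are hypothetical.
//
//   std::vector<std::vector<PjRtDevice*>> devices = {
//       {client->addressable_devices()[0]},   // replica 0, partition 0
//       {client->addressable_devices()[1]}};  // replica 1, partition 0
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       DevicesToDeviceAssignment(devices));
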
class CpuAllocator : public tensorflow::Allocator {
 public:
  CpuAllocator() = default;

  std::string Name() override { return "cpu"; }

  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    return tensorflow::port::AlignedMalloc(num_bytes, alignment);
  }
  void DeallocateRaw(void* ptr) override {
    return tensorflow::port::AlignedFree(ptr);
  }
};

PjRtStreamExecutorClient::PjRtStreamExecutorClient(
    std::string platform_name, LocalClient* client,
    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
    int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
    std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
    bool should_stage_host_to_device_transfers,
    std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
    : platform_id_(tensorflow::Fingerprint64(platform_name)),
      platform_name_(std::move(platform_name)),
      client_(client),
      host_memory_allocator_(std::move(host_memory_allocator)),
      owned_allocator_(std::move(allocator)),
      owned_devices_(std::move(devices)),
      process_index_(process_index),
      should_stage_host_to_device_transfers_(
          should_stage_host_to_device_transfers),
      gpu_run_options_(std::move(gpu_run_options)),
      thread_pool_(
          tensorflow::Env::Default(), "pjrt_thread_pool",
          std::max<int>(DefaultThreadPoolSize(), client->device_count())),
      transpose_cache_(1024) {
  if (owned_allocator_ != nullptr) {
    allocator_ = owned_allocator_.get();
  } else {
    allocator_ = client_->backend().memory_allocator();
  }

  if (!host_memory_allocator_) {
    host_memory_allocator_ = std::make_unique<CpuAllocator>();
  }

  for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
       owned_devices_) {
    devices_.push_back(device.get());
    CHECK(id_to_device_.insert({device->id(), device.get()}).second)
        << "Duplicate device id: " << device->id();

    if (device->IsAddressable()) {
      addressable_devices_.push_back(device.get());
    }
    device->SetClient(this);
  }
  // TODO(phawkins): we don't really promise anything about the order of
  // these devices, but users may be depending on the current order. Sort into
  // device ordinal order, which is the historical order these values have
  // appeared.
  absl::c_sort(addressable_devices_,
               [](const PjRtDevice* a, const PjRtDevice* b) {
                 return a->local_hardware_id() < b->local_hardware_id();
               });
}

StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  return client_->backend().computation_placer()->AssignDevices(num_replicas,
                                                                num_partitions);
}

StatusOr<std::unique_ptr<HloCostAnalysis>>
PjRtStreamExecutorClient::GetHloCostAnalysis() {
  return std::make_unique<HloCostAnalysis>(
      client_->backend().compiler()->ShapeSizeBytesFunction());
}

namespace {

// Ensures that it is safe to deallocate any buffers that have been enqueued in
// an operation on stream. Called only in rare error cases that are triggered
// during enqueue. These cases generally correspond to resource exhaustion.
void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
  switch (local_device->allocation_model()) {
    case LocalDeviceState::kAsynchronous:
      // We can safely deallocate any dangling buffers immediately. NOTE: this
      // assumes that any buffers enqueued on stream are local to stream's
      // executor, and manual action may be needed if that condition is not met.
      break;

    case LocalDeviceState::kComputeSynchronized:
      // This will stall computation but that's ok in this very rare error
      // case.
      if (stream != local_device->compute_stream()) {
        local_device->compute_stream()->ThenWaitFor(stream);
      }
      break;

    case LocalDeviceState::kSynchronous:
      // This will stall the calling thread but that's ok in this very rare
      // error case. If the stall fails just crash, since we have no other
      // way to synchronize.
      TF_CHECK_OK(stream->BlockHostUntilDone());
      break;
  }
}

// Does all necessary bookkeeping, after a buffer is successfully enqueued onto
// a stream, to ensure that the buffer will be kept alive until its use on that
// stream is complete.
//
//   device_buffer:              the buffer that was enqueued.
//   buffer_local_device:        the device the buffer was allocated on.
//   stream_local_device:        the device that manages usage_stream.
//   event:                      an event that was recorded on usage_stream
//                               after the usage of device_buffer was enqueued.
//   usage_stream:               the stream the operation using device_buffer
//                               was enqueued on.
//   prefer_to_retain_reference: relevant only for the compute synchronous
//                               allocation model. If true, retain a reference
//                               to device_buffer until after the operation
//                               completes. If false then the compute stream
//                               will have to be synchronized past event before
//                               device_buffer can be freed.
//
// prefer_to_retain_reference encodes a heuristic set by the caller for the
// compute synchronous model:
//
// Generally when a buffer is the destination of a copy to a device, it will
// subsequently be used on the device's compute stream before being freed. In
// that case, there is no need to retain a reference to the buffer. If the
// buffer is freed before being used on the compute stream, the free will be
// delayed until the host knows that event has completed, but this is expected
// to be uncommon.
//
// When a buffer is the source of a copy from a device, we need to either retain
// a reference to the buffer until the copy completes or serialize the compute
// stream behind the copy. It is often better to retain a reference since while
// that keeps memory alive longer, it avoids stalling the compute stream.
void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
                 LocalDeviceState* buffer_local_device,
                 LocalDeviceState* stream_local_device,
                 std::shared_ptr<BufferSequencingEvent> event,
                 se::Stream* usage_stream, bool prefer_to_retain_reference,
                 std::vector<std::shared_ptr<TrackedDeviceBuffer>>*
                     buffers_to_release = nullptr) {
  tensorflow::profiler::TraceMe traceme("RecordUsage");
  bool retain_buffer_until_completion =
      // If the buffer wasn't allocated on the same device as the stream, always
      // retain a reference.
      (stream_local_device != buffer_local_device) ||
      // In the synchronous allocation model, always retain a reference.
      (stream_local_device->allocation_model() ==
       LocalDeviceState::kSynchronous) ||
      // In the compute synchronous model, use the caller's heuristic.
      (stream_local_device->allocation_model() ==
           LocalDeviceState::kComputeSynchronized &&
       prefer_to_retain_reference);
  if (retain_buffer_until_completion) {
    if (buffers_to_release) {
      buffers_to_release->push_back(device_buffer.buffer());
    } else {
      buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer());
    }
  }
  device_buffer.ConvertUsageHold(usage_stream, event,
                                 retain_buffer_until_completion);
}

// Allocates the device buffers for a buffer that will be used as the
// destination of a copy, either from the host or another device. copy_stream
// may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the
// buffer is a tuple then the tuple tables are allocated, and all necessary
// synchronization for them is dealt with, before the buffer is returned.
//
// It is safe to delete the returned PjRtBuffer without further
// synchronization if an error occurs before the buffer is used.
//
// The caller may optionally provide a definition event to be recorded in
// the buffer.
// TODO(phawkins): replace on_host_shape here with on_device_shape.
StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
    const Shape& on_host_shape, PjRtDevice* device,
    LocalDeviceState* local_device, se::Stream* copy_stream,
    bool is_uninitialized_create, PjRtClient* client,
    std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
  if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
    return InvalidArgument("Can't make a buffer from an empty tuple");
  }

  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
  TransferManager* transfer_manager =
      se_client->client()->backend().transfer_manager();
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
                      transfer_manager->AllocateScopedShapedBuffer(
                          on_host_shape, se_client->allocator(),
                          local_device->device_ordinal()));
  if (local_device->allocation_model() ==
      LocalDeviceState::kComputeSynchronized) {
    if (copy_stream == nullptr) {
      CHECK(is_uninitialized_create);
    } else {
      copy_stream->ThenWaitFor(local_device->compute_stream());
    }
  } else {
    DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
        local_device->compute_stream()->parent(), dst_buffer));
  }
  Shape on_device_shape = dst_buffer.on_device_shape();

  absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
      definition_events;
  if (is_uninitialized_create) {
    // There is not going to be any copy into the buffer so in general we don't
    // need a definition event.
    if (local_device->allocation_model() ==
        LocalDeviceState::kComputeSynchronized) {
      // The allocation is not valid until the compute stream passes this point,
      // so add a definition event in the compute stream.
      definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
      TF_ASSIGN_OR_RETURN(EventPool::Handle event,
                          local_device->event_pool().ThenAllocateAndRecordEvent(
                              local_device->compute_stream()));
      definition_events.back()->SetSequencingEvent(
          std::move(event), local_device->compute_stream());
    }
    // If the caller provided a definition event then we record that.
    if (definition_event) {
      definition_events.emplace_back(definition_event);
    }
  } else {
    // We have at least one definition event, for the copy completing to
    // the device buffers.
    if (definition_event) {
      definition_events.emplace_back(definition_event);
    } else {
      definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
    }
  }
  se::Stream* tuple_table_stream = local_device->host_to_device_stream();
  if (on_device_shape.IsTuple()) {
    // We also need to copy the tuple tables, so we'll have an additional
    // definition event for that copy to complete.
    if (tuple_table_stream != copy_stream) {
      if (local_device->allocation_model() ==
          LocalDeviceState::kComputeSynchronized) {
        tuple_table_stream->ThenWaitFor(local_device->compute_stream());
      } else {
        DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
            local_device->compute_stream()->parent(), dst_buffer));
      }
    }

    TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
        tuple_table_stream, dst_buffer));
    // CAUTION: From this point onwards we need to be careful about returning
    // from error cases because we have started a transfer and must not allow
    // dst_buffer to be freed too soon in the non-async allocation models.

    definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
    StatusOr<EventPool::Handle> event_or =
        local_device->event_pool().ThenAllocateAndRecordEvent(
            tuple_table_stream);
    if (!event_or.ok()) {
      StallStreamOnError(local_device, tuple_table_stream);
      return event_or.status();
    }
    definition_events.back()->SetSequencingEvent(std::move(event_or).value(),
                                                 tuple_table_stream);
  }
  std::shared_ptr<TrackedDeviceBuffer> dst_device_buffer =
      TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
                                                  definition_events);

  auto py_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
      on_device_shape, std::move(dst_device_buffer), client, device);

  if (on_device_shape.IsTuple()) {
    // Add a usage hold for the tuple table write and immediately convert it to
    // the appropriate form of synchronization. prefer_to_retain_reference=false
    // means don't retain a memory reference until the transfer is complete when
    // using the ComputeSynchronized allocation model. This is a heuristic
    // because in the common case destination buffers will be used on the
    // compute stream and therefore don't require any synchronization before
    // being freed. If the buffer is allocated and never used, the free will
    // take longer and this is assumed to be ok.
    RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
                definition_events.back(), tuple_table_stream,
                /*prefer_to_retain_reference=*/false);
  }

  return py_buffer;
}

// Adds necessary synchronization after a copy has been enqueued to a buffer.
// definition_event was added when the buffer was allocated, but has not yet
// had an event recorded.
Status AddDestinationBufferSynchronization(
    LocalDeviceState* local_device,
    PjRtStreamExecutorBuffer::ScopedHold device_buffer,
    std::shared_ptr<BufferSequencingEvent> definition_event,
    se::Stream* copy_stream) {
  StatusOr<EventPool::Handle> event_or =
      local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream);
  if (!event_or.ok()) {
    StallStreamOnError(local_device, copy_stream);
    return event_or.status();
  }
  definition_event->SetSequencingEvent(std::move(event_or).value(),
                                       copy_stream);
  // prefer_to_retain_reference=false means don't retain a memory reference
  // until the transfer is complete when using the ComputeSynchronized
  // allocation model. This is a heuristic because in the common case
  // destination buffers will be used on the compute stream and therefore don't
  // require any synchronization before being freed. If the buffer is allocated
  // and never used, the free will take longer and this is assumed to be ok.
  RecordUsage(std::move(device_buffer), local_device, local_device,
              definition_event, copy_stream,
              /*prefer_to_retain_reference=*/false);
  return OkStatus();
}

}  // namespace

PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
  if (ok()) {
    parent_->DropHold(type_, buffer().get());
  }
}

PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
    : parent_(other.parent_),
      type_(other.type_),
      state_(other.state_),
      status_(std::move(other.status_)),
      buffer_(std::move(other.buffer_)) {
  // Preserve the invariant that status is invalid if buffer == nullptr.
  other.SetState(kMoved);
}

void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
    StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
  CHECK(!ok());
  if (buffer_or.ok()) {
    buffer_ = buffer_or.ValueOrDie();
    SetState(kValid);
  } else {
    status_ = buffer_or.status();
    buffer_ = nullptr;
    SetState(kError);
  }
  // Check the invariant holds.
  CHECK(!ok() || buffer_ != nullptr);
}

PjRtStreamExecutorBuffer::ScopedHold::ForClosure
PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
  CHECK(ok());
  ForClosure for_closure(parent_, type_, state_, std::move(status_),
                         std::move(buffer_));
  SetState(kReleased);
  return for_closure;
}

void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
    se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
    bool reference_held) {
  CHECK(ok());
  CHECK_EQ(type_, kUsage);
  parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
                            reference_held);
  SetState(kConverted);
}

void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
  CHECK(ok());
  CHECK_EQ(type_, kDonation);
  parent_->ConfirmDonation(buffer().get());
  SetState(kDonated);
}

void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
    ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
    ExecutionInput* execution_input,
    se::DeviceMemoryAllocator* allocator) const {
  CHECK(ok());
  if (type_ == kDonation) {
    buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
  } else {
    CHECK_EQ(type_, kUsage);
    buffer()->AddToInputAsImmutable(iterator, end);
  }
}

bool PjRtStreamExecutorBuffer::IsOnCpu() const {
  return client()->platform_id() == CpuId();
}

StatusOr<Shape> PjRtStreamExecutorBuffer::logical_on_device_shape() {
  if (on_device_shape_.is_static()) {
    return on_device_shape_;
  }
  auto* local_device = device_->local_device_state();
  auto* stream = local_device->GetDeviceToHostStream();
  ScopedHold device_buffer(this, ScopedHold::kUsage);
  {
    absl::MutexLock lock(&mu_);
    // We can't perform any other action while a donation hold is in progress.
    WaitForOutstandingDonationHold();
    if (device_buffer_ == nullptr) {
      return InvalidArgument(
          "logical_on_device_shape() called on deleted or donated buffer");
    }
    AcquireHoldLocked(&device_buffer);
  }

  WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
  ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
  StatusOr<EventPool::Handle> event_or =
      local_device->event_pool().AllocateEvent(stream->parent());
  if (!event_or.ok()) {
    return event_or.status();
  }
  Shape ret_shape = on_device_shape_;
  TransferManager* transfer_manager =
      client_->client()->backend().transfer_manager();
  TF_RETURN_IF_ERROR(
      transfer_manager->ReadDynamicShapes(stream, &shaped_buffer, &ret_shape));
  return ret_shape;
}

namespace {

// Implements PjRtBuffer::ExternalReference as a wrapped
// ScopedHold::kExternalReference.
class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
 public:
  explicit ScopedHoldAsExternalReference(
      PjRtStreamExecutorBuffer::ScopedHold hold)
      : external_reference_(std::move(hold)) {
    CHECK(external_reference_.type() ==
          PjRtStreamExecutorBuffer::ScopedHold::kExternalReference);
    data_ptr_ = external_reference_->device_memory().front().opaque();
  }

  ~ScopedHoldAsExternalReference() override = default;

 private:
  PjRtStreamExecutorBuffer::ScopedHold external_reference_;
};

}  // namespace

StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
PjRtStreamExecutorBuffer::AcquireExternalReference() {
  ScopedHold hold = GetBufferWithExternalReference();
  Status hold_status = hold.status();
  if (!hold_status.ok()) return hold_status;
  return std::unique_ptr<ExternalReference>(
      std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
}

class TrackedDeviceBufferExternalReference
    : public PjRtBuffer::ExternalReference {
 public:
  explicit TrackedDeviceBufferExternalReference(
      std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)
      : tracked_device_buffer_(std::move(tracked_device_buffer)) {
    data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque();
  }

  ~TrackedDeviceBufferExternalReference() override = default;

 private:
  std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer_;
};

StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
    bool wait_for_operations_to_complete) {
  if (on_device_shape_.IsTuple()) {
    return InvalidArgument(
        "ReleaseDeviceMemoryOwnership allowed only for non-tuple");
  }
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
      Release(wait_for_operations_to_complete));

  std::unique_ptr<PjRtBuffer::ExternalReference> ref;
  if (tracked_device_buffer) {
    ref = std::make_unique<TrackedDeviceBufferExternalReference>(
        std::move(tracked_device_buffer));
  }
  return ref;
}

StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer(
    const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
    std::optional<absl::Span<int64_t const>> byte_strides,
    HostBufferSemantics host_buffer_semantics,
    std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme(
      "PjRtStreamExecutorClient::BufferFromHostBuffer");
  Shape shape = ShapeUtil::MakeShape(type, dims);
  VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
          << shape.ToString() << " device: " << device->DebugString();
  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                          ->GetLocalDeviceState());

  absl::InlinedVector<int64_t, 4> tmp_strides;
  if (!byte_strides) {
    tmp_strides.resize(dims.size());
    TF_RETURN_IF_ERROR(
        ShapeUtil::ByteStrides(shape, absl::MakeSpan(tmp_strides)));
    byte_strides = tmp_strides;
  }
  int64_t size = ShapeUtil::ByteSizeOf(shape);

  TransferManager* transfer_manager = client()->backend().transfer_manager();
  TF_ASSIGN_OR_RETURN(Shape compact_shape,
                      transfer_manager->ChooseCompactLayoutForShape(shape));
  absl::InlinedVector<int64_t, 4> compact_shape_strides(
      compact_shape.dimensions_size());
  TF_RETURN_IF_ERROR(ShapeUtil::ByteStrides(
      compact_shape, absl::MakeSpan(compact_shape_strides)));
  bool host_and_device_strides_equal =
      (size == 0 || *byte_strides == compact_shape_strides);
  // The CPU platform is special because the "host" and the "device" are in the
  // same memory space. If the input shape is in the correct layout and we don't
  // want to defer the copy onto a thread, we can use the following fast
  // path.
  bool is_cpu_platform =
      local_device->executor()->platform()->id() == se::host::kHostPlatformId;
  if (is_cpu_platform) {
    // If we are on the host platform and the input buffer is sufficiently
    // aligned, we can simply point to the input array's data without any
    // further copies. At the time of writing we require a 16-byte alignment
    // because XLA may generate code which requires it.
    bool can_use_zero_copy =
        host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
        ((absl::bit_cast<std::uintptr_t>(data) &
          (cpu_function_runtime::MinAlign() - 1)) == 0);
    if (host_and_device_strides_equal &&
        (host_buffer_semantics ==
             HostBufferSemantics::kImmutableOnlyDuringCall ||
         can_use_zero_copy)) {
      std::function<void()> on_delete_callback;
      se::DeviceMemoryBase buffer;
      // If we are on the host platform and the input buffer is sufficiently
      // aligned, we can simply point to the input array's data without any
      // further copies. At the time of writing we require a 16-byte alignment
      // because XLA may generate code which requires it.
      if (can_use_zero_copy) {
        on_delete_callback = std::move(on_done_with_host_buffer);
        buffer = se::DeviceMemoryBase(
            const_cast<void*>(static_cast<const void*>(data)), size);
      } else {
        void* staging_buffer = host_memory_allocator()->AllocateRaw(
            cpu_function_runtime::MinAlign(), size);
        buffer = se::DeviceMemoryBase(staging_buffer, size);
        std::memcpy(staging_buffer, data, size);
        if (on_done_with_host_buffer) {
          on_done_with_host_buffer();
        }
        on_delete_callback = [staging_buffer, host_memory_allocator =
                                                  host_memory_allocator()]() {
          host_memory_allocator->DeallocateRaw(staging_buffer);
        };
      }
      absl::Span<const std::shared_ptr<BufferSequencingEvent>>
          definition_events;
      auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
          /*allocator=*/nullptr, local_device->device_ordinal(),
          std::initializer_list<se::DeviceMemoryBase>{buffer},
          definition_events, std::move(on_delete_callback));
      return std::unique_ptr<PjRtBuffer>(
          std::make_unique<PjRtStreamExecutorBuffer>(
              shape, std::move(device_buffer), this, device));
    }
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
      AllocateDestinationBuffer(compact_shape, device, local_device,
                                local_device->host_to_device_stream(),
                                /*is_uninitialized_create=*/false, this));

  PjRtStreamExecutorBuffer::ScopedHold device_buffer(
      py_buffer->GetBufferWithUsageHold());
  CHECK(device_buffer.ok());

  // If necessary, allocate a host-side buffer for staging host-to-device
  // transfers. On GPU this is a buffer in pinned memory.
  std::shared_ptr<void> staging_buffer;
  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
      should_stage_host_to_device_transfers() ||
      !host_and_device_strides_equal) {
    void* ptr = host_memory_allocator()->AllocateRaw(
        tensorflow::Allocator::kAllocatorAlignment, size);
    staging_buffer = std::shared_ptr<void>(
        ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
          host_memory_allocator->DeallocateRaw(ptr);
        });
  }

  std::shared_ptr<TransposePlan> transpose;
  if (!host_and_device_strides_equal) {
    absl::InlinedVector<int64_t, 4> permutation(dims.size());
    absl::c_reverse_copy(compact_shape.layout().minor_to_major(),
                         permutation.begin());
    absl::MutexLock lock(&transpose_mu_);
    TF_ASSIGN_OR_RETURN(transpose,
                        transpose_cache_.GetOrCreate(
                            primitive_util::ByteWidth(type), dims, permutation,
                            TransposePlan::Striding{*byte_strides}));
  }

  // Copy the buffer into a staging buffer before returning control to the
  // caller if the caller only guaranteed that the buffer is valid for the
  // duration of the call. Otherwise, we stage (if necessary) on a separate
  // thread.
  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
    if (transpose) {
      transpose->Execute(data, staging_buffer.get());
    } else {
      std::memcpy(staging_buffer.get(), data, size);
    }
    if (on_done_with_host_buffer) {
      on_done_with_host_buffer();
      on_done_with_host_buffer = nullptr;
    }
  }

  // The host to device transfer is performed on a thread pool, mostly because
  // it includes linearization that may be slow. It is OK to capture the
  // py_buffer pointer because the py_buffer can't be deleted until all the
  // usage holds have gone away.
  // TODO(misard) assess if it would be preferable to introduce a heuristic to
  // put the transfer into the calling thread for small literals.
  auto transfer_h2d =
      [local_client = client(), transfer_manager, local_device, data, size,
       movable_device_buffer{device_buffer.ToClosure()}, shape,
       py_buffer{py_buffer.get()},
       on_device_shape{py_buffer->on_device_shape()},
       staging_buffer{std::move(staging_buffer)},
       on_done_with_host_buffer{std::move(on_done_with_host_buffer)},
       host_buffer_semantics, transpose{std::move(transpose)}]() {
        PjRtStreamExecutorBuffer::ScopedHold device_buffer(
            movable_device_buffer);
        // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
        // to report failures from a callback. However, the operations here are
        // unlikely to fail and not recoverable even if we were to fail: DMAs to
        // memory that has already been allocated, and a possible Event
        // allocation.

        ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
        // If applicable on the backend, stage the transfer via host memory
        // allocated via the host_memory_allocator. On GPU, this is pinned
        // memory.
        if (staging_buffer) {
          // If we didn't already copy the input buffer into the staging buffer,
          // do so now.
          if (host_buffer_semantics !=
              HostBufferSemantics::kImmutableOnlyDuringCall) {
            if (transpose) {
              transpose->Execute(data, staging_buffer.get());
            } else {
              std::memcpy(staging_buffer.get(), data, size);
            }
          }
          // The buffer has the same dimension order as the on-device shape, but
          // is not tiled, etc.
          BorrowingLiteral literal(
              static_cast<const char*>(staging_buffer.get()),
              ShapeUtil::DeviceShapeToHostShape(on_device_shape));
          TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
              local_device->host_to_device_stream(), literal, buffer));
        } else {
          BorrowingLiteral literal(
              reinterpret_cast<const char*>(data),
              ShapeUtil::DeviceShapeToHostShape(on_device_shape));
          // Otherwise, just transfer the literal.
          TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
              local_device->host_to_device_stream(), literal, buffer));
        }

        std::shared_ptr<BufferSequencingEvent> event =
            device_buffer->definition_events()[0];
        TF_CHECK_OK(AddDestinationBufferSynchronization(
            local_device, std::move(device_buffer), event,
            local_device->host_to_device_stream()));

        local_device->ThenExecuteCallback(
            local_device->host_to_device_stream(),
            [staging_buffer{std::move(staging_buffer)},
             on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
              if (on_done_with_host_buffer) {
                on_done_with_host_buffer();
              }
            });
      };
  if (is_cpu_platform) {
    // Using the thread_pool would be a double thread hop; the code
    // already defers its work onto a stream (= thread on CPU).
    transfer_h2d();
  } else {
    thread_pool()->Schedule(transfer_h2d);
  }
  return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
}

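// Illustrative caller-side sketch (comment only, not compiled here): copying a
// host array to a device via the API above. The variables `client`,
// `host_data`, and `device` are hypothetical; the semantics value controls
// whether the host buffer must be copied before the call returns.
//
//   std::vector<float> host_data(16, 0.0f);
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtBuffer> buffer,
//       client->BufferFromHostBuffer(
//           host_data.data(), F32, /*dims=*/{4, 4},
//           /*byte_strides=*/std::nullopt,
//           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
//           /*on_done_with_host_buffer=*/nullptr, device));
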
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
                                                    PjRtDevice* device) {
  return CreateUninitializedBuffer(shape, device, nullptr);
}

StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(
    const Shape& shape, PjRtDevice* device,
    std::shared_ptr<BufferSequencingEvent> definition_event) {
  tensorflow::profiler::TraceMe traceme(
      "PjRtStreamExecutorClient::CreateUninitializedBuffer");
  VLOG(1) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
          << shape.ToString() << " device: " << device->DebugString();
  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                          ->GetLocalDeviceState());

  TransferManager* transfer_manager = client()->backend().transfer_manager();
  TF_ASSIGN_OR_RETURN(Shape compact_shape,
                      transfer_manager->ChooseCompactLayoutForShape(shape));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
      AllocateDestinationBuffer(compact_shape, device, local_device,
                                /*copy_stream=*/nullptr,
                                /*is_uninitialized_create=*/true, this,
                                definition_event));
  return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
}

StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
                                                PjRtDevice* device) {
  tensorflow::profiler::TraceMe traceme(
      "PjRtStreamExecutorClient::BufferFromHostLiteral");
  VLOG(1) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
          << literal.shape().ToString() << " device: " << device->DebugString();
  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                          ->GetLocalDeviceState());

  TransferManager* transfer_manager = client()->backend().transfer_manager();
  TF_ASSIGN_OR_RETURN(
      Shape compact_shape,
      transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
      AllocateDestinationBuffer(compact_shape, device, local_device,
                                local_device->host_to_device_stream(),
                                /*is_uninitialized_create=*/false, this));

  PjRtStreamExecutorBuffer::ScopedHold device_buffer(
      py_buffer->GetBufferWithUsageHold());
  CHECK(device_buffer.ok());

  // The host to device transfer is performed on a thread pool, mostly because
  // it includes linearization that may be slow. It is OK to capture the
  // py_buffer pointer because the py_buffer can't be deleted until all the
  // usage holds have gone away.
  // TODO(misard) assess if it would be preferable to introduce a heuristic to
  // put the transfer into the calling thread for small literals.
  auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                       movable_device_buffer{device_buffer.ToClosure()},
                       literal, py_buffer{py_buffer.get()},
                       on_device_shape{py_buffer->on_device_shape()}]() {
    PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
    // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
    // to report failures from a callback. However, the operations here are
    // unlikely to fail and not recoverable even if we were to fail: DMAs to
    // memory that has already been allocated, and a possible Event
    // allocation.

    se::Stream* h2d_stream = local_device->host_to_device_stream();
    ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
    TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
        h2d_stream, literal, buffer));

    std::shared_ptr<BufferSequencingEvent> event =
        device_buffer->definition_events()[0];
    TF_CHECK_OK(AddDestinationBufferSynchronization(
        local_device, std::move(device_buffer), event, h2d_stream));

    // This can sometimes catch the case where the literal memory has been
    // freed before the H2D transfer was issued.
    h2d_stream->RefreshStatus()
        .IgnoreError();  // Can return error::Unimplemented
    QCHECK(h2d_stream->ok());
  };
  thread_pool()->Schedule(transfer_h2d);
  return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
}

StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
    absl::Span<const Shape> shapes, PjRtDevice* device,
    PjRtCrossHostRecvNotifier notifier) {
  if (shapes.empty()) {
    return InvalidArgument(
        "shapes parameter empty in MakeCrossHostReceiveBuffers");
  }

  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                          ->GetLocalDeviceState());
  std::shared_ptr<BufferSequencingEvent> definition_event =
      std::make_shared<BufferSequencingEvent>();
  std::vector<std::unique_ptr<PjRtBuffer>> buffers;
  buffers.reserve(shapes.size());
  for (const auto& shape : shapes) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer> buffer,
        AllocateDestinationBuffer(shape, device, local_device,
                                  /*copy_stream=*/nullptr,
                                  /*is_uninitialized_create=*/false, this,
                                  definition_event));
    buffers.push_back(std::move(buffer));
  }

  TF_RETURN_IF_ERROR(EnqueueCrossHostReceive(
      buffers, std::move(definition_event), std::move(notifier), std::nullopt));
  return buffers;
}

StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorClient::MakeCrossHostReceiveBuffersForGather(
    absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
    PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) {
  VLOG(2) << "Making " << gather_details.size()
          << " cross host receive buffers for gather";
  if (gather_details.empty()) {
    return InvalidArgument(
        "gather_details parameter empty in "
        "MakeCrossHostReceiveBuffersForGather");
  }

  if (shapes.size() != gather_details.size()) {
    return InvalidArgument(
        "gather_details parameter has length %lld but shapes "
        "parameter has length %lld in "
        "MakeCrossHostReceiveBuffersForGather",
        gather_details.size(), shapes.size());
  }

  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                      tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                          ->GetLocalDeviceState());
  std::shared_ptr<BufferSequencingEvent> definition_event =
      std::make_shared<BufferSequencingEvent>();
  std::vector<std::unique_ptr<PjRtBuffer>> buffers;
  buffers.reserve(shapes.size());
  for (int i = 0; i < shapes.size(); ++i) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer> buffer,
        AllocateDestinationBuffer(shapes[i], device, local_device,
                                  /*copy_stream=*/nullptr,
                                  /*is_uninitialized_create=*/false, this,
                                  definition_event));
    buffers.push_back(std::move(buffer));
  }

  TF_RETURN_IF_ERROR(
      EnqueueCrossHostReceive(buffers, std::move(definition_event),
                              std::move(notifier), gather_details));
  return buffers;
}

StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
    void* device_ptr, const Shape& shape, PjRtDevice* device,
    std::function<void()> on_delete_callback) {
  se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape));
  absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
  auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
      /*allocator=*/nullptr, device->local_hardware_id(),
      std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
      std::move(on_delete_callback));
  return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
      shape, std::move(device_buffer), this, device));
}

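// Illustrative caller-side sketch (comment only, not compiled here): wrapping
// device memory that was allocated outside PjRt as a PjRtBuffer view using the
// function above. `client`, `device`, and `raw_device_ptr` are hypothetical;
// the on_delete_callback runs when the returned buffer is destroyed so the
// caller can release the underlying allocation.
//
//   Shape shape = ShapeUtil::MakeShape(F32, {1024});
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtBuffer> view,
//       client->CreateViewOfDeviceBuffer(
//           raw_device_ptr, shape, device,
//           /*on_delete_callback=*/[] { /* free raw_device_ptr here */ }));
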
1087 // Transfer the given literal to the infeed queue of the given local device.
TransferToInfeed(const LiteralSlice & literal)1088 Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
1089   // Only support infeed to local device.
1090   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1091   return local_device->client()->TransferToInfeedLocal(
1092       literal, local_device->device_ordinal());
1093 }
1094 
TransferFromOutfeed(MutableBorrowingLiteral literal)1095 Status PjRtStreamExecutorDevice::TransferFromOutfeed(
1096     MutableBorrowingLiteral literal) {
1097   VLOG(1) << "PjRtStreamExecutorDevice::TransferFromOutfeed";
1098   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
1099   return local_device->client()->TransferFromOutfeedLocal(
1100       local_device->device_ordinal(), literal);
1101 }
1102 
LookupAddressableDevice(int local_hardware_id) const1103 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
1104     int local_hardware_id) const {
1105   for (auto* device : addressable_devices_) {
1106     if (local_hardware_id == device->local_hardware_id()) {
1107       return device;
1108     }
1109   }
1110   return InvalidArgument("No matching device found for local_hardware_id %d",
1111                          local_hardware_id);
1112 }
1113 
PjRtStreamExecutorBuffer(Shape on_device_shape,std::shared_ptr<TrackedDeviceBuffer> device_buffer,PjRtClient * client,PjRtDevice * device)1114 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
1115     Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
1116     PjRtClient* client, PjRtDevice* device)
1117     : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
1118       on_device_shape_(std::move(on_device_shape)),
1119       device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
1120       device_buffer_(std::move(device_buffer)) {
1121   for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1122     holds_[i] = 0;
1123   }
1124 }
1125 
1126 PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
1127   Delete();
1128   for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
1129     CHECK_EQ(holds_[i], 0);
1130   }
1131 }
1132 
1133 void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
1134   auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1135     return holds_[ScopedHold::kUsage] == 0;
1136   };
1137   mu_.Await(absl::Condition(&not_in_usage_hold));
1138 }
1139 
1140 void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
1141   auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1142     return holds_[ScopedHold::kDonation] == 0;
1143   };
1144   mu_.Await(absl::Condition(&not_in_donation_hold));
1145 }
1146 
1147 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
1148 PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
1149   tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
1150   std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1151   TrackedDeviceBuffer::StreamAndEventContainer events;
1152   {
1153     absl::MutexLock lock(&mu_);
1154     // We first wait for a donation hold to complete if there is one in
1155     // progress. If the donation succeeds via ConfirmDonation() then it will
1156     // set device_buffer_ to nullptr before returning to this thread.
1157     WaitForOutstandingDonationHold();
1158     if (device_buffer_ == nullptr) {
1159       return std::shared_ptr<TrackedDeviceBuffer>();
1160     }
1161     // Set device_buffer_ to null now so that no other
1162     // thread can add a hold while we are in WaitForOutstandingUsageHolds()
1163     // below.
1164     std::swap(device_buffer_, device_buffer);
1165     WaitForOutstandingUsageHolds();
1166     // Now that all holds have completed and no more can be added, we can get
1167     // the final set of usage events.
1168     events = device_buffer->LockUseAndTransferUsageEvents();
1169   }
1170   LocalDeviceState* local_device_state = device_->local_device_state();
1171   if (wait_for_operations_to_complete) {
1172     // Block the host until all usage events have completed. Usage events
1173     // dominate definition events, so this also waits for the buffer to be
1174     // defined.
1175     std::unique_ptr<se::Stream> stream;
1176     for (const auto& stream_and_event : events) {
1177       if (!stream_and_event.event->IsComplete()) {
1178         if (stream == nullptr) {
1179           stream = local_device_state->BorrowStreamFromPool();
1180         }
1181         stream_and_event.event->WaitForEventOnStream(stream.get());
1182       }
1183     }
1184     if (stream != nullptr) {
1185       TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
1186       local_device_state->ReturnStreamToPool(std::move(stream));
1187     }
1188   } else {
1189     if (local_device_state->allocation_model() ==
1190         LocalDeviceState::kComputeSynchronized) {
1191       std::unique_ptr<se::Stream> block_stream;
1192       for (const auto& stream_and_event : events) {
1193         // We only need to do something for events that didn't already acquire a
1194         // reference to the buffer and that the compute stream hasn't already
1195         // waited for. Based on our heuristics this rare case should only
1196         // occur when a buffer was copied to a device and then never used there.
1197         // In that case we get a new stream and use it to hold onto a reference
1198         // to the buffer until the events are complete.
1199         if (!stream_and_event.reference_held &&
1200             !stream_and_event.event->DefinedOn(
1201                 local_device_state->compute_stream()) &&
1202             !stream_and_event.event->IsComplete()) {
1203           if (block_stream == nullptr) {
1204             block_stream = local_device_state->BorrowStreamFromPool();
1205           }
1206           stream_and_event.event->WaitForEventOnStream(block_stream.get());
1207         }
1208       }
1209       if (block_stream != nullptr) {
1210         se::Stream* block_stream_ptr = block_stream.release();
1211         local_device_state->ThenExecuteCallback(
1212             block_stream_ptr,
1213             [device_buffer, block_stream_ptr, local_device_state]() {
1214               local_device_state->ReturnStreamToPool(
1215                   std::unique_ptr<se::Stream>(block_stream_ptr));
1216             });
1217       }
1218     }
1219   }
1220   return device_buffer;
1221 }
1222 
1223 void PjRtStreamExecutorBuffer::Delete() {
1224   VLOG(1) << "PjRtStreamExecutorBuffer::Delete";
1225   // When wait_for_operations_to_complete is false, Release should never fail.
1226   TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
1227 }
1228 
1229 bool PjRtStreamExecutorBuffer::IsDeleted() {
1230   absl::MutexLock lock(&mu_);
1231   return device_buffer_ == nullptr;
1232 }
1233 
1234 StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
1235 PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
1236   // All callers should have called WaitForOutstandingDonationHold().
1237   CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1238   if (type == ScopedHold::kDonation) {
1239     if (device_buffer_ == nullptr) {
1240       return InvalidArgument("Donation requested for invalid buffer");
1241     }
1242     if (holds_[ScopedHold::kExternalReference] > 0) {
1243       return InvalidArgument(
1244           "Donation requested for buffer with external reference");
1245     }
1246     // First add the donation hold.
1247     ++holds_[type];
1248     // Then wait for any usage holds to be dropped or converted. No new usage
1249     // holds can be added until we drop the donation hold so this wait will
1250     // complete eventually.
1251     WaitForOutstandingUsageHolds();
1252     // Because we added a donation hold, nobody could release the buffer while
1253     // we were waiting.
1254     CHECK(device_buffer_ != nullptr);
1255   } else {
1256     if (device_buffer_ == nullptr) {
1257       return InvalidArgument("Buffer has been deleted or donated.");
1258     } else {
1259       ++holds_[type];
1260     }
1261   }
1262   return device_buffer_;
1263 }
1264 
1265 void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
1266   hold->Acquire(GetBufferForHoldLocked(hold->type()));
1267 }
1268 
1269 void PjRtStreamExecutorBuffer::ConvertUsageHold(
1270     TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
1271     std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
1272   absl::MutexLock lock(&mu_);
1273   CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1274   buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
1275   CHECK_GT(holds_[ScopedHold::kUsage], 0);
1276   --holds_[ScopedHold::kUsage];
1277 }
1278 
1279 void PjRtStreamExecutorBuffer::ConfirmDonation(
1280     TrackedDeviceBuffer* device_buffer) {
1281   {
1282     absl::MutexLock lock(&mu_);
1283     CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1284     CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1285     CHECK_EQ(holds_[ScopedHold::kDonation], 1);
1286     holds_[ScopedHold::kDonation] = 0;
1287     CHECK(device_buffer_.get() == device_buffer);
1288     // As a sanity check ensure no more usage events can be added to the buffer.
1289     device_buffer->LockUseAndTransferUsageEvents();
1290     // Give up ownership of the device memory so we don't free it when the last
1291     // reference to device_buffer_ goes away.
1292     device_buffer->ReleaseDeviceMemory();
1293     // Make *this invalid so it can't be used again. Any threads blocking in
1294     // Release or GetBufferWithHold will see an invalid buffer and return.
1295     device_buffer_.reset();
1296   }
1297 }
1298 
1299 void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
1300                                         TrackedDeviceBuffer* buffer) {
1301   absl::MutexLock lock(&mu_);
1302   CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
1303   CHECK_GT(holds_[type], 0);
1304   --holds_[type];
1305   if (type == ScopedHold::kDonation) {
1306     CHECK_EQ(holds_[ScopedHold::kDonation], 0);
1307     CHECK_EQ(holds_[ScopedHold::kUsage], 0);
1308     CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
1309   }
1310 }
1311 
1312 PjRtFuture<Status> PjRtStreamExecutorBuffer::ToLiteral(
1313     MutableLiteralBase* literal) {
1314   VLOG(1) << "PjRtStreamExecutorBuffer::ToLiteral";
1315   if (IsEmptyTuple()) {
1316     return PjRtFuture<Status>(
1317         InvalidArgument("ToLiteral called on empty tuple"));
1318   }
1319   LocalDeviceState* local_device = device_->local_device_state();
1320   se::Stream* stream = local_device->GetDeviceToHostStream();
1321   ScopedHold device_buffer(this, ScopedHold::kUsage);
1322   {
1323     absl::MutexLock lock(&mu_);
1324     // We can't perform any other action while a donation hold is in progress.
1325     WaitForOutstandingDonationHold();
1326     if (device_buffer_ == nullptr) {
1327       return PjRtFuture<Status>(InvalidArgument(
1328           "CopyToHostAsync() called on deleted or donated buffer"));
1329     }
1330     AcquireHoldLocked(&device_buffer);
1331   }
1332 
1333   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
1334   ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
1335   StatusOr<EventPool::Handle> event_or =
1336       local_device->event_pool().AllocateEvent(stream->parent());
1337   if (!event_or.ok()) {
1338     return PjRtFuture<Status>(event_or.status());
1339   }
1340   auto promise = PjRtFuture<Status>::CreatePromise();
1341   client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
1342       stream, shaped_buffer, literal,
1343       [promise](Status status) mutable { promise.Set(status); });
1344 
1345   auto usage_event = std::make_shared<BufferSequencingEvent>();
1346   local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
1347   usage_event->SetSequencingEvent(std::move(event_or).value(), stream);
1348   // When using the ComputeSynchronized allocation model, retain a reference to
1349   // the device_buffer until the copy completes, to ensure that the buffer isn't
1350   // deleted or donated while it is still in use. The choice of retaining a
1351   // reference at the host is a heuristic; the alternative is to ensure, before
1352   // freeing the buffer, that the compute stream is synchronized past the
1353   // transfer, but it seems better to hold onto the buffer too long than to
1354   // stall the compute stream, particularly since the overwhelmingly common
1355   // use case of CopyToHostAsync will hold onto the reference long enough to
1356   // read the buffer in a subsequent call to ToLiteral.
1357   RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
1358               stream,
1359               /*prefer_to_retain_reference=*/true);
1360 
1361   return PjRtFuture<Status>(
1362       std::move(promise),
1363       /*on_block_start=*/
1364       []() {
1365         tensorflow::profiler::TraceMeProducer traceme(
1366             "PjRtStreamExecutorBuffer::ToLiteral");
1367         VLOG(1) << "PjRtStreamExecutorBuffer::ToLiteral";
1368         return PjRtFutureHelpers::ProfilingKeys(
1369             {/*traceme_context_id =*/traceme.GetContextId()});
1370       },
1371       /*on_block_end=*/
1372       [](PjRtFutureHelpers::ProfilingKeys keys) {
1373         tensorflow::profiler::TraceMeConsumer traceme(
1374             "PjRtStreamExecutorBuffer::ToLiteral", keys.traceme_context_id);
1375       });
1376 }
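
// Usage sketch (illustrative; `buffer` and `literal` are assumed names): the
// future returned by ToLiteral resolves once the device-to-host transfer
// enqueued above completes.
//
//   Literal literal(
//       ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape()));
//   PjRtFuture<Status> done = buffer->ToLiteral(&literal);
//   TF_RETURN_IF_ERROR(done.Await());  // or done.OnReady(...) to stay async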
1377 
1378 StatusOr<size_t> PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() const {
1379   absl::MutexLock lock(&mu_);
1380   if (device_buffer_ == nullptr) {
1381     return InvalidArgument(
1382         "GetOnDeviceSizeInBytes called on deleted or donated buffer");
1383   }
1384   if (device_buffer_->device_memory().size() != 1) {
1385     return InvalidArgument(
1386         "GetOnDeviceSizeInBytes called on tuple-shaped buffer");
1387   }
1388   return device_buffer_->device_memory()[0].size();
1389 }
1390 
1391 PjRtFuture<Status> PjRtStreamExecutorBuffer::CopyRawToHost(
1392     void* dst, int64_t offset, int64_t transfer_size) {
1393   return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size);
1394 }
1395 
1396 StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
1397   absl::MutexLock lock(&mu_);
1398   if (device_buffer_ == nullptr) {
1399     return InvalidArgument(
1400         "Attempted to fetch value of invalid/deleted buffer.");
1401   }
1402   return device_buffer_->AsShapedBuffer(on_device_shape_);
1403 }
1404 
1405 PjRtStreamExecutorBuffer::ScopedHold
1406 PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
1407   absl::MutexLock lock(&mu_);
1408   // Ensure that at most one donation hold can be in progress at a time.
1409   WaitForOutstandingDonationHold();
1410   ScopedHold hold(this, type);
1411   AcquireHoldLocked(&hold);
1412   return hold;
1413 }
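
// Hedged sketch of the hold lifecycle around GetBufferWithHold (`buffer` and
// the enqueued work are illustrative):
//
//   PjRtStreamExecutorBuffer::ScopedHold hold =
//       buffer->GetBufferWithUsageHold();
//   if (!hold.ok()) return hold.status();  // buffer was deleted or donated
//   // ... enqueue device work that reads hold->device_memory() ...
//   // On success, convert the hold into a usage event recorded on the
//   // stream (e.g. via RecordUsage/ConvertUsageHold); if the hold goes out
//   // of scope unconverted, its destructor drops it via DropHold().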
1414 
1415 StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1416                    std::shared_ptr<BufferSequencingEvent>>>
1417 PjRtStreamExecutorBuffer::CopyToDeviceHelper(
1418     PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
1419     LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
1420     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
1421   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
1422                       AllocateDestinationBuffer(
1423                           ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
1424                           dst_device, dst_local_device, transfer_stream,
1425                           /*is_uninitialized_create=*/false, client_));
1426 
1427   TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
1428 
1429   WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream);
1430 
1431   ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
1432   CHECK(dst_device_buffer.ok());
1433   ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);
1434 
1435   // Copy the leaf buffers.
1436   StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
1437       [&]() -> StatusOr<std::shared_ptr<BufferSequencingEvent>> {
1438     for (const auto& leaf : src_buffer.buffers().leaves()) {
1439       const ShapeIndex& index = leaf.first;
1440       const se::DeviceMemoryBase& input_buffer = leaf.second;
1441       const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
1442       TF_RET_CHECK(input_buffer.size() == output_buffer.size())
1443           << "input: " << input_buffer.size()
1444           << " output: " << output_buffer.size();
1445       if (input_buffer.size() != 0) {
1446         TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
1447             transfer_stream, dst_local_device->compute_stream(), input_buffer,
1448             output_buffer));
1449       }
1450     }
1451     std::shared_ptr<BufferSequencingEvent> event =
1452         dst_device_buffer->definition_events()[0];
1453     TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
1454         transfer_local_device, std::move(dst_device_buffer), event,
1455         transfer_stream));
1456     return event;
1457   }();
1458   if (!copy_event_or.ok()) {
1459     StallStreamOnError(transfer_local_device, transfer_stream);
1460     if (transfer_local_device == dst_local_device) {
1461       // Some copies may have been enqueued before the error was returned, and
1462       // StallStreamOnError only makes sure the destination device is ok, so
1463       // make sure that the src buffer remains valid until after any transfers
1464       // have completed.
1465       device_->local_device_state()->ThenRelease(transfer_stream,
1466                                                  std::move(src_device_buffer));
1467     }
1468     return copy_event_or.status();
1469   }
1470 
1471   return std::pair<std::unique_ptr<PjRtBuffer>,
1472                    std::shared_ptr<BufferSequencingEvent>>(
1473       std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
1474       std::move(copy_event_or).value());
1475 }
1476 
1477 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
1478     PjRtDevice* dst_device) {
1479   tensorflow::profiler::TraceMe traceme(
1480       "PjRtStreamExecutorBuffer::CopyToDevice");
1481   VLOG(1) << "PjRtStreamExecutorBuffer::CopyToDevice";
1482   if (dst_device == device_) {
1483     return InvalidArgument(
1484         "CopyToDevice cannot accept the same source and destination devices");
1485   }
1486 
1487   // Copying across PjRtClients involves a copy through the host.
1488   if (dst_device->client() != client_) {
1489     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteralSync());
1490     // Avoid use-after-free on `literal` due to unsequenced move and use.
1491     Literal* literal_pointer = literal.get();
1492     absl::InlinedVector<int64_t, 4> byte_strides(
1493         literal->shape().dimensions_size());
1494     TF_RETURN_IF_ERROR(
1495         ShapeUtil::ByteStrides(literal->shape(), absl::MakeSpan(byte_strides)));
1496     return dst_device->client()->BufferFromHostBuffer(
1497         literal_pointer->untyped_data(),
1498         literal_pointer->shape().element_type(),
1499         literal_pointer->shape().dimensions(), byte_strides,
1500         PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
1501         [literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
1502   }
1503 
1504   TF_ASSIGN_OR_RETURN(
1505       LocalDeviceState * dst_local_device,
1506       tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
1507           ->GetLocalDeviceState());
1508   LocalDeviceState* transfer_local_device =
1509       client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
1510                                                 : dst_local_device;
1511   CHECK_EQ(dst_local_device->allocation_model(),
1512            transfer_local_device->allocation_model());
1513 
1514   se::Stream* transfer_stream =
1515       transfer_local_device->GetDeviceToDeviceStream();
1516 
1517   ScopedHold src_device_buffer(this, ScopedHold::kUsage);
1518   {
1519     absl::MutexLock lock(&mu_);
1520     // We can't perform any other action while a donation hold is in progress.
1521     WaitForOutstandingDonationHold();
1522     if (device_buffer_ == nullptr) {
1523       return InvalidArgument(
1524           "CopyToDevice called on deleted or donated buffer");
1525     }
1526     AcquireHoldLocked(&src_device_buffer);
1527   }
1528 
1529   StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
1530                      std::shared_ptr<BufferSequencingEvent>>>
1531       buffer_and_event_or = CopyToDeviceHelper(
1532           dst_device, dst_local_device, transfer_local_device, transfer_stream,
1533           src_device_buffer.buffer());
1534   if (!buffer_and_event_or.ok()) {
1535     return buffer_and_event_or.status();
1536   }
1537 
1538   auto& buffer_and_event = buffer_and_event_or.ValueOrDie();
1539   std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
1540   std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;
1541 
1542   // prefer_to_retain_reference=true means that, when using the
1543   // ComputeSynchronized allocation model, retain a reference to the
1544   // src_device_buffer until the copy completes. This is a heuristic; the
1545   // alternative is to ensure, before freeing the buffer, that the compute
1546   // stream is synchronized past the transfer, but it seems better to hold onto
1547   // the buffer too long than to stall the compute stream.
1548   RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
1549               transfer_local_device, event, transfer_stream,
1550               /*prefer_to_retain_reference=*/true);
1551 
1552   return std::move(buffer);
1553 }
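
// Usage sketch (illustrative; `other_device` is an assumed name). Copies
// within the same PjRtClient go device-to-device on the transfer stream
// chosen above; copies to a device owned by a different client bounce
// through a host literal, as in the branch near the top of CopyToDevice:
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copy,
//                       buffer->CopyToDevice(other_device));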
1554 
1555 void PjRtStreamExecutorBuffer::CopyToRemoteDevice(
1556     absl::string_view serialized_descriptor, RemoteSendCallback on_done) {
1557   VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDevice";
1558   client_->CopyToRemoteDevice(this, serialized_descriptor, std::move(on_done));
1559 }
1560 
1561 void PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered(
1562     absl::Span<const std::pair<std::string, RemoteSendCallback>>
1563         serialized_descriptors_and_callbacks,
1564     const ScatterDetails& scatter_details) {
1565   VLOG(1) << "PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered";
1566   client_->CopyToRemoteDeviceScattered(
1567       this, serialized_descriptors_and_callbacks, scatter_details);
1568 }
1569 
1570 PjRtFuture<Status> PjRtStreamExecutorBuffer::GetReadyFuture() {
1571   std::shared_ptr<TrackedDeviceBuffer> device_buffer;
1572   PjRtFuture<Status>::Promise definition_promise;
1573   {
1574     absl::MutexLock lock(&mu_);
1575     if (device_buffer_ == nullptr) {
1576       return PjRtFuture<Status>(InvalidArgument(
1577           "GetReadyFuture() called on deleted or donated buffer"));
1578     }
1579     if (!definition_promise_) {
1580       device_buffer = device_buffer_;
1581       definition_promise_ = PjRtFuture<Status>::CreatePromise();
1582     }
1583     definition_promise = definition_promise_;
1584   }
1585 
1586   if (device_buffer) {
1587     LocalDeviceState* local_device_state = device_->local_device_state();
1588     std::unique_ptr<se::Stream> stream;
1589     for (auto& event : device_buffer->definition_events()) {
1590       if (!event->IsComplete()) {
1591         if (stream == nullptr) {
1592           stream = local_device_state->BorrowStreamFromPool();
1593         }
1594         event->WaitForEventOnStream(stream.get());
1595       }
1596     }
1597     if (stream != nullptr) {
1598       auto* stream_ptr = stream.release();
1599       // We already borrowed a stream from the pool so we can safely do the
1600       // callback directly on that stream instead of bouncing through
1601       // local_device_state->ThenExecuteCallback. The direct callback saves
1602       // significant time.
1603       stream_ptr->ThenDoHostCallback(
1604           [definition_promise, stream_ptr, local_device_state]() mutable {
1605             local_device_state->ReturnStreamToPool(
1606                 std::unique_ptr<se::Stream>(stream_ptr));
1607             definition_promise.Set(OkStatus());
1608           });
1609     } else {
1610       // All events are already complete.
1611       definition_promise.Set(OkStatus());
1612     }
1613   }
1614 
1615   return PjRtFuture<Status>(
1616       std::move(definition_promise),
1617       /*on_block_start=*/
1618       []() {
1619         tensorflow::profiler::TraceMeProducer traceme(
1620             "PjRtStreamExecutorBuffer::Await");
1621         VLOG(1) << "PjRtStreamExecutorBuffer::Await";
1622         return PjRtFutureHelpers::ProfilingKeys(
1623             {/*traceme_context_id=*/traceme.GetContextId()});
1624       },
1625       /*on_block_end=*/
1626       [](PjRtFutureHelpers::ProfilingKeys keys) {
1627         tensorflow::profiler::TraceMeConsumer traceme(
1628             "PjRtStreamExecutorBuffer::Await", keys.traceme_context_id);
1629       });
1630 }
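
// Usage sketch (illustrative; the callback body is an assumption): the
// future returned by GetReadyFuture() lets callers defer work until the
// buffer's definition events have been reached, without blocking the host:
//
//   buffer->GetReadyFuture().OnReady([](Status s) {
//     if (!s.ok()) {
//       LOG(ERROR) << "buffer was not ready: " << s;
//     }
//   });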
1631 
1632 namespace {
1633 
1634 // Helper struct for the tuple that is transiently constructed to hold the
1635 // arguments of an execution.
1636 struct TupleHandle {
1637   // The ExecutionInput describing the tuple.
1638   ExecutionInput execution_input;
1639   // A definition event that has been recorded on the host_to_device stream
1640   // after the tuple table transfer.
1641   std::shared_ptr<BufferSequencingEvent> event;
1642 };
1643 
1644 Status CheckCompatibleShapes(bool strict_shape_checking,
1645                              const Shape& buffer_shape,
1646                              const Shape& execution_shape,
1647                              const TransferManager& transfer_manager,
1648                              int parameter_index) {
1649   // TODO(misard) Support casting of tuple parameters.
1650   if (strict_shape_checking || buffer_shape.IsTuple()) {
1651     if (!ShapeUtil::Equal(buffer_shape, execution_shape)) {
1652       return InvalidArgument(
1653           "Executable expected shape %s for argument %d but got "
1654           "incompatible "
1655           "shape %s",
1656           ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1657           ShapeUtil::HumanStringWithLayout(buffer_shape));
1658     }
1659   } else {
1660     if (transfer_manager.GetByteSizeRequirement(buffer_shape) !=
1661         transfer_manager.GetByteSizeRequirement(execution_shape)) {
1662       return InvalidArgument(
1663           "Executable expected shape %s for argument %d but got "
1664           "incompatible "
1665           "shape %s",
1666           ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
1667           ShapeUtil::HumanStringWithLayout(buffer_shape));
1668     }
1669   }
1670   return OkStatus();
1671 }
1672 
1673 // Makes a tuple from the arguments to an execution.
1674 StatusOr<TupleHandle> MakeTupleHelper(
1675     PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
1676     bool strict_shape_checking, const Shape& tupled_parameter_shape,
1677     absl::Span<PjRtBuffer* const> py_buffers,
1678     absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1679     int device_ordinal) {
1680   se::DeviceMemoryAllocator* allocator = client->allocator();
1681   TransferManager* transfer_manager =
1682       client->client()->backend().transfer_manager();
1683 
1684   if (tupled_parameter_shape.tuple_shapes_size() != py_buffers.size()) {
1685     return InvalidArgument("Executable expected %lld parameters but got %lld",
1686                            tupled_parameter_shape.tuple_shapes_size(),
1687                            py_buffers.size());
1688   }
1689   for (int i = 0; i < py_buffers.size(); ++i) {
1690     TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1691         strict_shape_checking, py_buffers[i]->on_device_shape(),
1692         tupled_parameter_shape.tuple_shapes(i), *transfer_manager, i));
1693   }
1694 
1695   se::Stream* stream = local_device->host_to_device_stream();
1696   TF_ASSIGN_OR_RETURN(
1697       se::OwningDeviceMemory root_table_memory,
1698       allocator->Allocate(
1699           device_ordinal,
1700           transfer_manager->GetByteSizeRequirement(tupled_parameter_shape)));
1701 
1702   if (local_device->allocation_model() ==
1703       LocalDeviceState::kComputeSynchronized) {
1704     stream->ThenWaitFor(local_device->compute_stream());
1705   } else {
1706     DCHECK(transfer_manager->CanBufferBeAccessedNow(
1707         local_device->compute_stream()->parent(), root_table_memory.cref()));
1708   }
1709 
1710   ExecutionInput execution_input(tupled_parameter_shape);
1711   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1712       execution_input.MutableBuffers()->begin();
1713   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1714       execution_input.MutableBuffers()->end();
1715   // First set the root tuple table which is the first buffer in the ShapeTree.
1716   execution_input.SetBuffer(
1717       input_iterator->first,
1718       MaybeOwningDeviceMemory(std::move(root_table_memory)));
1719   ++input_iterator;
1720   // Then set each sub-tuple in turn from the parameters.
1721   for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
1722        device_buffers) {
1723     device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
1724                              allocator);
1725   }
1726   CHECK(input_iterator == iterator_end);
1727 
1728   TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
1729       stream, execution_input.Buffers()));
1730   StatusOr<EventPool::Handle> event_or =
1731       local_device->event_pool().ThenAllocateAndRecordEvent(stream);
1732   if (!event_or.ok()) {
1733     StallStreamOnError(local_device, stream);
1734     return event_or.status();
1735   }
1736 
1737   auto transfer_event = std::make_shared<BufferSequencingEvent>();
1738   transfer_event->SetSequencingEvent(std::move(event_or).value(), stream);
1739   return TupleHandle({std::move(execution_input), std::move(transfer_event)});
1740 }
1741 
1742 // Converts a ScopedShapedBuffer returned from an execution into a
1743 // PjRtBuffer.
1744 std::unique_ptr<PjRtBuffer> OutputBufferHelper(
1745     ScopedShapedBuffer* result_buffer,
1746     std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
1747     PjRtDevice* device, LocalDeviceState* local_device,
1748     std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release) {
1749   std::shared_ptr<TrackedDeviceBuffer> out_buffer =
1750       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
1751                                                   {definition_event});
1752   auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
1753       result_buffer->on_device_shape(), std::move(out_buffer), client, device);
1754   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
1755               definition_event, local_device->compute_stream(),
1756               /*prefer_to_retain_reference=*/false, &buffers_to_release);
1757   return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
1758 }
1759 }  // namespace
1760 
1761 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
1762     std::vector<std::unique_ptr<LocalExecutable>> executables,
1763     bool parameter_is_tupled_arguments,
1764     std::shared_ptr<DeviceAssignment> device_assignment,
1765     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
1766     std::vector<PjRtDevice*> addressable_devices,
1767     PjRtStreamExecutorClient* client)
1768     : client_(client),
1769       device_assignment_(std::move(device_assignment)),
1770       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
1771       addressable_device_logical_ids_(
1772           std::move(addressable_device_logical_ids)),
1773       addressable_devices_(std::move(addressable_devices)) {
1774   TransferManager* transfer_manager =
1775       client_->client()->backend().transfer_manager();
1776   executables_.reserve(executables.size());
1777   for (auto& executable : executables) {
1778     const auto& computation_layout =
1779         executable->executable()->module().entry_computation_layout();
1780     std::vector<Shape> parameter_shapes;
1781     parameter_shapes.reserve(computation_layout.parameter_count());
1782     for (int i = 0; i < computation_layout.parameter_count(); ++i) {
1783       parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape(
1784           computation_layout.parameter_shape(i)));
1785     }
1786     executables_.emplace_back(std::move(executable));
1787     on_device_executable_parameter_shapes_.push_back(
1788         std::move(parameter_shapes));
1789   }
1790 
1791   int num_partitions;
1792   if (device_assignment_ == nullptr) {
1793     // This must go after `executables_` is initialized.
1794     VLOG(3) << "PjRtStreamExecutorExecutable portable single-core";
1795     num_partitions = 1;
1796     CHECK(addressable_devices_.empty());
1797   } else {
1798     // This must go after `executables_` is initialized.
1799     VLOG(3) << "PjRtStreamExecutorExecutable device_assignment:\n"
1800             << device_assignment_->ToString();
1801     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
1802     CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
1803         << "Inconsistent local device count.";
1804     num_partitions = device_assignment_->computation_count();
1805   }
1806 
1807   // SPMD sharding produces a single executable for multiple partitions.
1808   if (executables_.size() > 1) {
1809     CHECK_EQ(num_partitions, executables_.size())
1810         << "Number of executables " << executables_.size()
1811         << " did not match number of partitions " << num_partitions;
1812   }
1813 }
1814 
1815 Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
1816   parameters_that_must_be_donated_.reserve(executables_.size());
1817   for (auto& executable : executables_) {
1818     TF_ASSIGN_OR_RETURN(std::vector<int> parameters_to_donate,
1819                         ComputeParametersThatMustBeDonated(
1820                             executable->executable()->module(), tuple_inputs));
1821     parameters_that_must_be_donated_.emplace_back(
1822         std::move(parameters_to_donate));
1823   }
1824   return OkStatus();
1825 }
1826 
1827 absl::string_view PjRtStreamExecutorExecutable::name() const {
1828   Executable* executable = executables_[0]->executable();
1829   if (executable->has_module()) {
1830     return executable->module().name();
1831   } else {
1832     return "<unknown executable>";
1833   }
1834 }
1835 
1836 absl::Span<int const> PjRtStreamExecutorExecutable::ParametersThatMustBeDonated(
1837     int executable_idx) const {
1838   return parameters_that_must_be_donated_[executable_idx];
1839 }
1840 
1841 StatusOr<std::vector<ExecutionInput>>
1842 PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
1843     int device_ordinal, const ExecuteOptions& options,
1844     absl::Span<const Shape> executable_parameter_shapes,
1845     absl::Span<PjRtBuffer* const> argument_handles,
1846     absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
1847     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
1848   std::vector<ExecutionInput> execution_inputs;
1849   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1850   TransferManager* transfer_manager =
1851       client_->client()->backend().transfer_manager();
1852   // Lift tuple_handle outside the conditional so that the event it holds is
1853   // not destroyed until after the loop below that waits on events.
1854   std::optional<TupleHandle> tuple_handle;
1855   if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
1856     TF_ASSIGN_OR_RETURN(
1857         tuple_handle,
1858         MakeTupleHelper(client_, device_state, options.strict_shape_checking,
1859                         executable_parameter_shapes[0], argument_handles,
1860                         device_buffers, device_ordinal));
1861     events.insert(tuple_handle->event.get());
1862     execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
1863   } else {
1864     if (argument_handles.size() != executable_parameter_shapes.size()) {
1865       return InvalidArgument("Executable expected %lld arguments but got %lld",
1866                              executable_parameter_shapes.size(),
1867                              argument_handles.size());
1868     }
1869     execution_inputs.reserve(argument_handles.size());
1870     for (int i = 0; i < argument_handles.size(); ++i) {
1871       PjRtBuffer* handle = argument_handles[i];
1872 
1873       // Make an ExecutionInput from the device buffer.
1874       TF_RETURN_IF_ERROR(CheckCompatibleShapes(
1875           options.strict_shape_checking, handle->on_device_shape(),
1876           executable_parameter_shapes[i], *transfer_manager, i));
1877       execution_inputs.emplace_back(executable_parameter_shapes[i]);
1878       ExecutionInput& execution_input = execution_inputs.back();
1879       ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
1880           execution_input.MutableBuffers()->begin();
1881       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
1882           execution_input.MutableBuffers()->end();
1883       device_buffers[i].AddToInput(&input_iterator, iterator_end,
1884                                    &execution_input, client_->allocator());
1885       CHECK(input_iterator == iterator_end);
1886     }
1887   }
1888 
1889   for (BufferSequencingEvent* event : events) {
1890     event->WaitForEventOnStream(device_state->compute_stream());
1891   }
1892 
1893   return execution_inputs;
1894 }
1895 
1896 // Enqueues a computation onto the compute stream. Each buffer returned in
1897 // device_buffers has a usage hold added that must be dropped on error or
1898 // converted on success.
1899 StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
1900     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
1901     int executable_idx, const RunId& run_id, const ExecuteOptions& options,
1902     PjRtDevice* device,
1903     std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
1904     std::shared_ptr<DeviceAssignment> device_assignment,
1905     std::vector<std::function<void()>>& compute_callbacks) const {
1906   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
1907                            ->local_device_state()
1908                            ->device_ordinal();
1909   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1910   tensorflow::profiler::TraceMeConsumer activity(
1911       "PjRtStreamExecutorExecutable::EnqueueExecution",
1912       tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
1913   VLOG(3) << "Replica " << replica << ", partition " << partition
1914           << " mapped to device ordinal for execution: " << device_ordinal;
1915 
1916   absl::flat_hash_set<BufferSequencingEvent*> events;
1917   device_buffers->reserve(argument_handles.size());
1918   absl::Span<int const> donated_params =
1919       ParametersThatMustBeDonated(executable_idx);
1920   auto donate_it = donated_params.begin();
1921   for (int i = 0; i < argument_handles.size(); ++i) {
1922     auto* handle =
1923         tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
1924     if (handle->device() != device) {
1925       return InvalidArgument(
1926           "Buffer passed to Execute() as argument %d to replica %d is on "
1927           "device %s, but replica is assigned to device %s.",
1928           i, replica, handle->device()->DebugString(), device->DebugString());
1929     }
1930     bool must_donate = donate_it != donated_params.end() && *donate_it == i;
1931     if (must_donate) {
1932       ++donate_it;
1933     }
1934     device_buffers->emplace_back(handle->GetBufferWithHold(
1935         must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
1936                     : PjRtStreamExecutorBuffer::ScopedHold::kUsage));
1937     PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
1938         device_buffers->back();
1939     if (!device_buffer.ok()) {
1940       return InvalidArgument(
1941           "Invalid buffer passed to Execute() as argument %d to replica %d: "
1942           "%s",
1943           i, replica, device_buffer.status().ToString());
1944     }
1945     // If we are trying to donate the buffer wait on the usage events as well
1946     // as the definition events to ensure that all reads have been completed
1947     // before the buffer is mutated. Usage holds are excluded during a donation
1948     // hold so we know that the set of usage events won't be modified while we
1949     // are enqueueing.
1950     GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
1951                           &events);
1952   }
1953 
1954   if (options.arguments_are_tupled) {
1955     if (!parameter_is_tupled_arguments_) {
1956       return InvalidArgument(
1957           "Arguments may only be supplied as a tuple when the executable was "
1958           "compiled with a single tupled parameter");
1959     }
1960     if (argument_handles.size() != 1) {
1961       return InvalidArgument(
1962           "Option arguments_are_tupled was true but %d buffers were passed to "
1963           "execution",
1964           argument_handles.size());
1965     }
1966   }
1967 
1968   TF_ASSIGN_OR_RETURN(
1969       std::vector<ExecutionInput> execution_inputs,
1970       MakeExecutionInputsAndWaitForEvents(
1971           device_ordinal, options,
1972           on_device_executable_parameter_shapes_[executable_idx],
1973           argument_handles, *device_buffers, events));
1974 
1975   ExecutableRunOptions run_options;
1976   run_options.set_stream(device_state->compute_stream());
1977   run_options.set_host_to_device_stream(device_state->host_to_device_stream());
1978   run_options.set_allocator(client_->allocator());
1979   run_options.set_intra_op_thread_pool(
1980       client_->client()->backend().eigen_intra_op_thread_pool_device());
1981   run_options.set_device_assignment(device_assignment.get());
1982   run_options.set_run_id(run_id);
1983   run_options.set_rng_seed(device_state->GetNewPrngSeed());
1984   run_options.set_gpu_executable_run_options(client_->gpu_run_options());
1985   run_options.set_launch_id(options.launch_id);
1986   if (run_options.launch_id() != 0) {
1987     VLOG(3) << "launch id for " << name() << ": " << run_options.launch_id();
1988   }
1989 
1990   // The choice of where we wait is arbitrary; the reason for the wait is
1991   // pacing to avoid problems such as memory fragmentation and running ahead
1992   // too far, not for correctness. Placing it before the executable launch
1993   // allows the inputs for the next executable to be fetched even if the
1994   // launch is delayed.
1995   std::shared_ptr<Semaphore::ScopedReservation> compute_reservation;
1996   {
1997     tensorflow::profiler::TraceMe traceme("ComputeSemaphoreAcquire");
1998     compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
1999         device_state->compute_semaphore().ScopedAcquire(1));
2000   }
2001 
2002   StatusOr<ExecutionOutput> result_buffer_or_status =
2003       executables_[executable_idx]->RunAsync(std::move(execution_inputs),
2004                                              run_options);
2005 
2006   VLOG(1) << "Replica " << replica << " partition " << partition
2007           << " completed; ok=" << result_buffer_or_status.ok();
2008 
2009   if (!result_buffer_or_status.ok()) {
2010     return result_buffer_or_status.status();
2011   }
2012 
2013   if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
2014     ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
2015     // If we used a transient tuple for the arguments we donated its root table
2016     // buffer. In that case, and/or if we donated any input buffers that were
2017     // not aliased, the donated buffers are going to be passed back to us via
2018     // the execution output. We need to ensure they aren't freed until after
2019     // execution completes. (Currently XLA does not support aliasing tuple
2020     // tables, so if any donated parameter is a tuple there will be donated but
2021     // unaliased buffers.)
2022     std::vector<se::OwningDeviceMemory> donated_memory =
2023         execution_output.ConsumeToBeReleased();
2024     absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
2025     donated_ptrs.reserve(donated_memory.size());
2026     for (se::OwningDeviceMemory& owning : donated_memory) {
2027       // Release the owning memory so we can pass it to the closure.
2028       donated_ptrs.push_back(owning.Release());
2029     }
2030     compute_callbacks.push_back(
2031         [references{std::make_tuple(executables_[executable_idx],
2032                                     compute_reservation, device_assignment)},
2033          donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
2034          device_ordinal]() {
2035           for (const auto& ptr : donated_ptrs) {
2036             TF_CHECK_OK(allocator->Deallocate(device_ordinal, ptr));
2037           }
2038         });
2039   } else {
2040     // Any donated memory returned by the ExecutionOutput can be immediately
2041     // freed.
2042     compute_callbacks.push_back(
2043         [to_release{std::make_tuple(executables_[executable_idx],
2044                                     compute_reservation,
2045                                     device_assignment)}]() {});
2046   }
2047 
2048   return std::move(result_buffer_or_status).value().ConsumeResult();
2049 }
2050 
2051 std::vector<std::unique_ptr<PjRtBuffer>>
2052 PjRtStreamExecutorExecutable::MakeOutputBuffers(
2053     int device_ordinal, const ExecuteOptions& options,
2054     ScopedShapedBuffer result_buffer,
2055     std::shared_ptr<BufferSequencingEvent> definition_event, PjRtDevice* device,
2056     std::vector<std::function<void()>>& compute_callbacks,
2057     std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
2058     const {
2059   tensorflow::profiler::TraceMe traceme("MakeOutputBuffers");
2060   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
2061   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2062   if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
2063     int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
2064     outputs.reserve(tuple_count);
2065     // Take ownership of each of the output values, leaving only the root table
2066     // in result_buffer.
2067     for (int i = 0; i < tuple_count; ++i) {
2068       ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i});
2069       outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event,
2070                                            client_, device, device_state,
2071                                            buffers_to_release));
2072     }
2073     if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
2074       // Don't release the root buffer until after execution completes.
2075       ShapedBuffer root_buffer_holder = result_buffer.release();
2076       se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer();
2077       compute_callbacks.push_back(
2078           [root_buffer, allocator{client_->allocator()}, device_ordinal]() {
2079             TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer));
2080           });
2081     }
2082   } else {
2083     outputs.push_back(OutputBufferHelper(&result_buffer, definition_event,
2084                                          client_, device, device_state,
2085                                          buffers_to_release));
2086   }
2087   return outputs;
2088 }
2089 
2090 StatusOr<PjRtLoadedExecutable::Result>
2091 PjRtStreamExecutorExecutable::ExecuteHelper(
2092     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
2093     const RunId& run_id, const ExecuteOptions& options, bool fill_future,
2094     PjRtDevice* device) const {
2095   const uint64_t start_time_usecs = tensorflow::Env::Default()->NowMicros();
2096   std::shared_ptr<DeviceAssignment> device_assignment;
2097   if (device == nullptr) {
2098     CHECK(device_assignment_ != nullptr);
2099     const int device_id = (*device_assignment_)(replica, partition);
2100     TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
2101     device_assignment = device_assignment_;
2102   } else {
2103     CHECK(device_assignment_ == nullptr);
2104     CHECK_EQ(replica, 0);
2105     CHECK_EQ(partition, 0);
2106     CHECK(addressable_devices_.empty());
2107     device_assignment = std::make_shared<DeviceAssignment>(1, 1);
2108     (*device_assignment)(0, 0) = device->id();
2109   }
2110 
2111   CHECK_EQ(device->process_index(), client_->process_index());
2112   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
2113                            ->local_device_state()
2114                            ->device_ordinal();
2115   tensorflow::profiler::TraceMe traceme(
2116       "PjRtStreamExecutorExecutable::ExecuteHelper");
2117   VLOG(1) << "Replica " << replica << ", partition " << partition
2118           << " mapped to device ordinal for execution: " << device_ordinal;
2119 
2120   // SPMD sharding produces a single executable for multiple partitions.
2121   int executable_idx = executables_.size() > 1 ? partition : 0;
2122 
2123   std::vector<std::function<void()>> compute_callbacks;
2124   std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
2125   device_buffers.reserve(argument_handles.size());
2126   StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
2127       argument_handles, replica, partition, executable_idx, run_id, options,
2128       device, &device_buffers, std::move(device_assignment), compute_callbacks);
2129 
2130   if (!result_buffer_or_status.ok()) {
2131     LOG(ERROR) << "Execution of replica " << replica
2132                << " failed: " << result_buffer_or_status.status();
2133     return result_buffer_or_status.status();
2134   }
2135   ScopedShapedBuffer result_buffer = std::move(result_buffer_or_status).value();
2136 
2137   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
2138   se::Stream* stream = device_state->compute_stream();
2139   StatusOr<EventPool::Handle> event_or =
2140       device_state->event_pool().ThenAllocateAndRecordEvent(stream);
2141   if (!event_or.ok()) {
2142     StallStreamOnError(device_state, stream);
2143     for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2144       if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
2145         // Even though there was an error we need to call ConfirmDonation, which
2146         // renders b invalid, since the computation has been enqueued and b has
2147         // been donated.
2148         b.ConfirmDonation();
2149       }
2150     }
2151     return event_or.status();
2152   }
2153   auto definition_event = std::make_shared<BufferSequencingEvent>();
2154   definition_event->SetSequencingEvent(std::move(event_or).value(), stream);
2155   std::vector<std::shared_ptr<TrackedDeviceBuffer>> buffers_to_release;
2156   std::vector<std::unique_ptr<PjRtBuffer>> outputs = MakeOutputBuffers(
2157       device_ordinal, options, std::move(result_buffer), definition_event,
2158       device, compute_callbacks, buffers_to_release);
2159 
2160   for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
2161     // prefer_to_retain_reference=false because when using the
2162     // ComputeSynchronized allocation model we don't need to retain a reference
2163     // to the device_buffer during execution because by definition the compute
2164     // stream is synchronized past the execution.
2165     if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
2166       RecordUsage(std::move(b), device_state, device_state, definition_event,
2167                   stream,
2168                   /*prefer_to_retain_reference=*/false, &buffers_to_release);
2169     } else {
2170       CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
2171       b.ConfirmDonation();
2172     }
2173   }
2174 
2175   std::optional<PjRtFuture<Status>> future;
2176   if (fill_future) {
2177     auto promise = PjRtFuture<Status>::CreatePromise();
2178     future = PjRtFuture<Status>(promise);
2179     compute_callbacks.push_back(
2180         [promise = std::move(promise)]() mutable { promise.Set(OkStatus()); });
2181   }
2182   device_state->ThenExecuteCallback(
2183       stream, [callbacks{std::move(compute_callbacks)},
2184                buffers_to_release{std::move(buffers_to_release)}]() {
2185         for (auto& fn : callbacks) {
2186           fn();
2187         }
2188       });
2189   ReportExecutableEnqueueTime(tensorflow::Env::Default()->NowMicros() -
2190                               start_time_usecs);
2191   return Result({/*future=*/std::move(future), /*buffers=*/std::move(outputs)});
2192 }
2193 
2194 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
2195 PjRtStreamExecutorExecutable::Execute(
2196     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
2197     const ExecuteOptions& options,
2198     std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
2199   if (device_assignment_ == nullptr) {
2200     return InvalidArgument("Execute expects a non-null device_assignment");
2201   }
2202 
2203   RunId run_id;
2204   tensorflow::profiler::TraceMeProducer activity(
2205       "PjRtStreamExecutorExecutable::Execute",
2206       tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
2207 
2208   const int num_addressable_devices = addressable_devices_.size();
2209 
2210   if (argument_handles.size() != num_addressable_devices) {
2211     return InvalidArgument(
2212         "Attempted to execute with %d argument lists when local device "
2213         "count is %d (total replica count: %d, partition count: %d)",
2214         argument_handles.size(), num_addressable_devices, num_replicas(),
2215         num_partitions());
2216   }
2217 
2218   VLOG(1) << "Executing computation " << name()
2219           << "; num_replicas=" << num_replicas()
2220           << " num_partitions=" << num_partitions()
2221           << " num_addressable_devices=" << num_addressable_devices;
2222   std::vector<StatusOr<Result>> results(num_addressable_devices);
2223   if (num_addressable_devices == 1) {
2224     // Fast-path if there is only one device — run the computation on the
2225     // current thread.
2226     const int replica = addressable_device_logical_ids_[0].replica;
2227     const int partition = addressable_device_logical_ids_[0].partition;
2228     results[0] = ExecuteHelper(argument_handles[0], replica, partition, run_id,
2229                                options, returned_futures.has_value());
2230   } else {
2231     absl::Mutex mu;
2232     int running = num_addressable_devices;
2233     int failed = 0;
2234     Status first_failure_status;
2235 
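         // Launch one execution per addressable device on that device's
         // execute thread; `running` and `failed` are updated under `mu` as
         // each per-device ExecuteHelper call returns.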
2236     for (int i = 0; i < num_addressable_devices; ++i) {
2237       const int replica = addressable_device_logical_ids_[i].replica;
2238       const int partition = addressable_device_logical_ids_[i].partition;
2239       PjRtDevice* device = addressable_devices_[i];
2240       const LocalDeviceState& device_state =
2241           *tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
2242                ->local_device_state();
2243       device_state.execute_thread()->Schedule([&, replica, partition, i] {
2244         results[i] =
2245             ExecuteHelper(argument_handles[i], replica, partition, run_id,
2246                           options, returned_futures.has_value());
2247 
2248         absl::MutexLock lock(&mu);
2249         --running;
2250         if (!results[i].ok()) {
2251           if (failed == 0) {
2252             first_failure_status = results[i].status();
2253           }
2254           ++failed;
2255         }
2256       });
2257     }
2258 
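         // Block until every scheduled execution has completed or at least one
         // has failed; on failure, the branch below waits a bounded time for
         // the stragglers before aborting to escape a possible device-side
         // deadlock.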
2259     auto done_running_or_failed = [&]() {
2260       mu.AssertHeld();
2261       return running == 0 || failed > 0;
2262     };
2263     absl::MutexLock lock(&mu);
2264     mu.Await(absl::Condition(&done_running_or_failed));
2265     if (failed > 0) {
2266       auto done_running = [&]() {
2267         mu.AssertHeld();
2268         return running == 0;
2269       };
2270       // If execution does not terminate within a reasonable amount of time,
2271       // we may be stuck at a cross-replica barrier on-device. Terminate the
2272       // process since that's the only way we can escape this situation at the
2273       // moment (b/130629719).
2274       if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
2275                                absl::Seconds(10))) {
2276         LOG(FATAL)
2277             << "Replicated computation launch failed, but not all replicas "
2278                "terminated. Aborting process to work around deadlock. "
2279                "Failure message (there may have been multiple failures, see "
2280                "the error log for all failures): \n\n"
2281             << first_failure_status.error_message();
2282       }
2283     }
2284   }
2285   VLOG(1) << "Replicated execution complete.";
2286 
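       // Unwrap the per-device results: on the first failure, drop any futures
       // and return the (possibly annotated) error status; otherwise collect
       // the output buffers, and completion futures if requested, in device
       // order.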
2287   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
2288       num_addressable_devices);
2289   if (returned_futures.has_value()) {
2290     returned_futures->reserve(num_addressable_devices);
2291   }
2292   for (int i = 0; i < num_addressable_devices; ++i) {
2293     const int replica = addressable_device_logical_ids_[i].replica;
2294     const int partition = addressable_device_logical_ids_[i].partition;
2295     auto& statusor = results[i];
2296     if (!statusor.ok()) {
2297       if (returned_futures.has_value()) {
2298         returned_futures->clear();
2299       }
2300       if (num_addressable_devices == 1) {
2301         return statusor.status();
2302       } else {
2303         return AppendStatus(
2304             statusor.status(),
2305             absl::StrFormat("while running replica %d and partition %d of a "
2306                             "replicated computation (other "
2307                             "replicas may have failed as well).",
2308                             replica, partition));
2309       }
2310     }
2311     wrapped_results[i] = std::move(statusor->buffers);
2312     if (returned_futures.has_value()) {
2313       returned_futures->push_back(*std::move(statusor->future));
2314     }
2315   }
2316   return wrapped_results;
2317 }
2318 
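     // Runs the computation on a single device that must already appear in the
     // executable's device assignment; the matching replica/partition ids are
     // looked up among this client's addressable devices.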
2319 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2320 PjRtStreamExecutorExecutable::ExecuteSharded(
2321     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2322     const ExecuteOptions& options,
2323     std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
2324   if (device_assignment_ == nullptr) {
2325     return InvalidArgument("ExecuteSharded expects a non-null device_assignment");
2326   }
2327   for (int i = 0; i < addressable_devices_.size(); ++i) {
2328     if (addressable_devices_[i] == device) {
2329       VLOG(1) << "ExecuteSharded executes computation " << name()
2330               << " on assigned replica/partition on device "
2331               << device->DebugString();
2332       TF_ASSIGN_OR_RETURN(
2333           auto result,
2334           ExecuteHelper(argument_handles,
2335                         addressable_device_logical_ids_[i].replica,
2336                         addressable_device_logical_ids_[i].partition, RunId(),
2337                         options, fill_future));
2338       returned_future = std::move(result.future);
2339       return std::move(result.buffers);
2340     }
2341   }
2342   return InvalidArgument(
2343       "ExecuteSharded attempted to execute on device id %d, which is not "
2344       "addressable by this client",
2345       device->id());
2346 }
2347 
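     // Runs a portable executable (compiled without a device assignment and
     // with a single replica and partition) on the caller-supplied device.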
2348 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
2349 PjRtStreamExecutorExecutable::ExecutePortable(
2350     absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
2351     const ExecuteOptions& options,
2352     std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
2353   if (device_assignment_ != nullptr) {
2354     return InvalidArgument("ExecutePortable requires a portable executable");
2355   }
2356   if (num_replicas() != 1 || num_partitions() != 1) {
2357     return InvalidArgument(
2358         "ExecutePortable expects a single-core executable but got "
2359         "one with %d replicas and %d partitions",
2360         num_replicas(), num_partitions());
2361   }
2362   if (device == nullptr) {
2363     return InvalidArgument("ExecutePortable expects a device to be specified");
2364   }
2365   VLOG(1) << "ExecutePortable executes single-core portable executable "
2366           << name();
2367   TF_ASSIGN_OR_RETURN(auto result, ExecuteHelper(argument_handles,
2368                                                  /*replica=*/0,
2369                                                  /*partition=*/0, RunId(),
2370                                                  options, fill_future, device));
2371   returned_future = std::move(result.future);
2372   return std::move(result.buffers);
2373 }
2374 
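     // Returns the HloModule backing each local executable; fails if any local
     // executable was built without retaining its module.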
2375 StatusOr<std::vector<std::shared_ptr<HloModule>>>
2376 PjRtStreamExecutorExecutable::GetHloModules() const {
2377   std::vector<std::shared_ptr<HloModule>> modules;
2378   modules.reserve(executables().size());
2379   for (const auto& local_exec : executables()) {
2380     if (!local_exec->executable()->has_module()) {
2381       return InvalidArgument("Executable does not have HLO modules.");
2382     }
2383     modules.push_back(local_exec->executable()->shared_module());
2384   }
2385   return std::move(modules);
2386 }
2387 
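     // Resolves the device assignment implied by `options` and collects the
     // devices from that assignment that are addressable by this client,
     // filling in default build options (compile thread pool, allocator,
     // device ordinal) where the caller left them unset.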
2388 StatusOr<PjRtStreamExecutorClient::ExecutableExtras>
2389 PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
2390   ExecutableExtras extras;
2391   std::shared_ptr<DeviceAssignment>& device_assignment =
2392       extras.device_assignment;
2393   std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
2394       addressable_device_logical_ids = extras.addressable_device_logical_ids;
2395   std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;
2396 
2397   ExecutableBuildOptions& build_options = options->executable_build_options;
2398   if (!build_options.compile_thread_pool()) {
2399     build_options.set_compile_thread_pool(thread_pool());
2400   }
2401   if (!build_options.device_allocator()) {
2402     build_options.set_device_allocator(allocator());
2403   }
2404 
2405   int num_replicas;
2406   int num_partitions;
2407   TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
2408       options->compile_portable_executable, &options->executable_build_options,
2409       [this](int num_replicas, int num_partitions) {
2410         return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
2411       },
2412       &num_replicas, &num_partitions, &device_assignment));
2413 
2414   // Find devices that are addressable by this client/task.
2415   if (device_assignment != nullptr) {
2416     addressable_device_logical_ids.reserve(num_replicas * num_partitions);
2417     addressable_devices.reserve(num_replicas * num_partitions);
2418     for (int replica = 0; replica < num_replicas; ++replica) {
2419       for (int partition = 0; partition < num_partitions; ++partition) {
2420         int device_id = (*device_assignment)(replica, partition);
2421         TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
2422         if (device->process_index() != process_index()) {
2423           VLOG(3) << "Non-local device: " << device_id;
2424           continue;
2425         }
2426         PjRtLoadedExecutable::LogicalDeviceIds logical_device_ids;
2427         logical_device_ids.replica = replica;
2428         logical_device_ids.partition = partition;
2429         addressable_device_logical_ids.push_back(std::move(logical_device_ids));
2430         addressable_devices.push_back(device);
2431       }
2432     }
2433     if (addressable_devices.empty()) {
2434       return InvalidArgument(
2435           "Device assignment (%s) does not have any local devices.",
2436           device_assignment->ToString());
2437     }
2438 
2439     if (build_options.device_ordinal() < 0) {
2440       build_options.set_device_ordinal(
2441           addressable_devices.front()->local_hardware_id());
2442     }
2443   }
2444   return extras;
2445 }
2446 
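     // Compiles an XlaComputation into an executable loaded onto this client's
     // devices. A minimal sketch of the call (assuming a
     // PjRtStreamExecutorClient* `client` and an XlaComputation `computation`
     // built elsewhere; illustrative only):
     //
     //   CompileOptions options;
     //   options.executable_build_options.set_num_replicas(1);
     //   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> loaded,
     //                       client->Compile(computation, options));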
2447 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
2448 PjRtStreamExecutorClient::Compile(const XlaComputation& computation,
2449                                   CompileOptions options) {
2450   tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
2451   VLOG(1) << "PjRtStreamExecutorClient::Compile";
2452 
2453   TF_ASSIGN_OR_RETURN(ExecutableExtras extras, GetExecutableExtras(&options));
2454   std::shared_ptr<DeviceAssignment>& device_assignment =
2455       extras.device_assignment;
2456   std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
2457       addressable_device_logical_ids = extras.addressable_device_logical_ids;
2458   std::vector<PjRtDevice*>& addressable_devices = extras.addressable_devices;
2459 
2460   std::vector<const Shape*> argument_layout_pointers;
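       // Choose argument layouts: any layout the caller did not specify is
       // picked by the backend's TransferManager as the compact layout for
       // that argument's shape.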
2461   TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
2462       computation,
2463       [local_client = client()](Shape shape) {
2464         return local_client->backend()
2465             .transfer_manager()
2466             ->ChooseCompactLayoutForShape(shape);
2467       },
2468       options.argument_layouts, &options.executable_build_options,
2469       &argument_layout_pointers));
2470 
2471   TF_ASSIGN_OR_RETURN(
2472       std::vector<std::unique_ptr<LocalExecutable>> local_executables,
2473       client()->Compile(computation, argument_layout_pointers,
2474                         options.executable_build_options));
2475 
2476   auto executable = std::make_unique<PjRtStreamExecutorExecutable>(
2477       std::move(local_executables), options.parameter_is_tupled_arguments,
2478       std::move(device_assignment), std::move(addressable_device_logical_ids),
2479       std::move(addressable_devices), this);
2480   TF_RETURN_IF_ERROR(
2481       executable->SetUpDonation(options.parameter_is_tupled_arguments));
2482   return std::unique_ptr<PjRtLoadedExecutable>(std::move(executable));
2483 }
2484 
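     // Lowers the MLIR module to an XlaComputation (tupling its arguments when
     // the options request it) and delegates to the XlaComputation overload
     // above.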
2485 StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
2486 PjRtStreamExecutorClient::Compile(mlir::ModuleOp module,
2487                                   CompileOptions options) {
2488   XlaComputation xla_computation;
2489   TF_RETURN_IF_ERROR(MlirToXlaComputation(
2490       module, xla_computation,
2491       /*use_tuple_args=*/options.parameter_is_tupled_arguments,
2492       /*return_tuple=*/false));
2493   return Compile(xla_computation, options);
2494 }
2495 
2496 }  // namespace xla
2497