xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/transpose.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/stream_executor/stream.h"
57 
58 namespace xla {
59 
60 class PjRtStreamExecutorDevice : public PjRtDevice {
61  public:
62   explicit PjRtStreamExecutorDevice(
63       int id, std::unique_ptr<LocalDeviceState> local_device_state,
64       std::string device_kind, int process_index = 0)
id_(id)65       : id_(id),
66         device_ordinal_(
67             local_device_state ? local_device_state->device_ordinal() : -1),
68         local_device_state_(std::move(local_device_state)),
69         process_index_(process_index),
70         device_kind_(std::move(device_kind)) {}
~PjRtStreamExecutorDevice()71   ~PjRtStreamExecutorDevice() override {}
72 
73   // Must set client exactly once.
SetClient(PjRtClient * client)74   void SetClient(PjRtClient* client) {
75     CHECK(client_ == nullptr);
76     client_ = client;
77     // We have to define debug_string_ and to_string_ here, because
78     // platform_name() requires client_ to be set.
79     debug_string_ = absl::StrCat(platform_name(), ":", id());
80     to_string_ = absl::StrCat(platform_name(), "(id=", id(), ")");
81   }
82 
process_index()83   int process_index() const override { return process_index_; }
84 
85   // Return `platform_id` from client.
86   PjRtPlatformId platform_id() const;
87 
88   // Return `platform_name` from client.
89   absl::string_view platform_name() const;
90 
client()91   PjRtClient* client() const override { return client_; }
92 
id()93   int id() const override { return id_; }
94 
IsAddressable()95   bool IsAddressable() const override { return device_ordinal_ != -1; }
96 
local_hardware_id()97   int local_hardware_id() const override { return device_ordinal_; }
98 
99   // If this is a device local to this host, returns a LocalDeviceState object
100   // that can be used to manipulate the device. Returns nullptr if the device is
101   // not local to this host.
local_device_state()102   LocalDeviceState* local_device_state() const {
103     return local_device_state_.get();
104   }
105 
106   // If this is a device local to this host, returns a LocalDeviceState object
107   // that can be used to manipulate the device. Returns an error if the device
108   // is not local to this host.
109   StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
110 
device_kind()111   absl::string_view device_kind() const override { return device_kind_; }
112 
113   absl::string_view ToString() const override;
114 
115   absl::string_view DebugString() const override;
116 
117   Status TransferToInfeed(const LiteralSlice& literal) override;
118 
119   Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;
120 
CreateAsyncTrackingEvent(absl::string_view description)121   std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
122       absl::string_view description) const override {
123     return nullptr;
124   }
125 
Attributes()126   const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
127       const override {
128     return attributes_;
129   }
130 
131  protected:
132   absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
133 
134  private:
135   const int id_;
136   const int device_ordinal_;  // -1 means not local.
137   const std::unique_ptr<LocalDeviceState> local_device_state_;
138   const int process_index_;
139   const std::string device_kind_;
140   std::string debug_string_;
141   std::string to_string_;
142   PjRtClient* client_ = nullptr;
143 };
144 
145 class PjRtStreamExecutorClient : public PjRtClient {
146  public:
147   // `allocator` may null, in which case the platform default allocator is used.
148   explicit PjRtStreamExecutorClient(
149       std::string platform_name, LocalClient* client,
150       std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
151       int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
152       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
153       bool should_stage_host_to_device_transfers,
154       std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
155   ~PjRtStreamExecutorClient() override = default;
156 
process_index()157   int process_index() const override { return process_index_; }
158 
device_count()159   int device_count() const override { return devices_.size(); }
addressable_device_count()160   int addressable_device_count() const override {
161     return addressable_devices_.size();
162   }
devices()163   absl::Span<PjRtDevice* const> devices() const override { return devices_; }
addressable_devices()164   absl::Span<PjRtDevice* const> addressable_devices() const override {
165     return addressable_devices_;
166   }
167 
LookupDevice(int device_id)168   StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
169     auto it = id_to_device_.find(device_id);
170     if (it != id_to_device_.end()) {
171       return it->second;
172     }
173     return InvalidArgument("No matching device found for device_id %d",
174                            device_id);
175   }
176 
177   StatusOr<PjRtDevice*> LookupAddressableDevice(
178       int local_hardware_id) const override;
179 
platform_id()180   PjRtPlatformId platform_id() const override { return platform_id_; }
platform_name()181   absl::string_view platform_name() const override { return platform_name_; }
platform_version()182   absl::string_view platform_version() const override { return "<unknown>"; }
runtime_type()183   PjRtRuntimeType runtime_type() const override { return kStreamExecutor; }
184 
185   // Most platforms expect device-to-device transfers to be enqueued on the
186   // source d2d stream, but some platforms use the destination d2d stream. This
187   // function specifies which one the platform expects.
EnqueueD2DTransfersOnSrcStream()188   virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
189 
190   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
191       int num_replicas, int num_partitions) const override;
192 
193   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
194       const XlaComputation& computation, CompileOptions options) override;
195   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
196       mlir::ModuleOp mlir_module, CompileOptions options) override;
197 
ExecutableFingerprint(const PjRtLoadedExecutable & executable)198   StatusOr<std::optional<std::string>> ExecutableFingerprint(
199       const PjRtLoadedExecutable& executable) const override {
200     return std::optional<std::string>();
201   }
202 
SerializeExecutable(const PjRtLoadedExecutable & executable)203   StatusOr<std::string> SerializeExecutable(
204       const PjRtLoadedExecutable& executable) const override {
205     return Unimplemented("SerializeExecutable not implemented on %s",
206                          platform_name());
207   }
208 
DeserializeExecutable(absl::string_view serialized,CompileOptions options)209   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
210       absl::string_view serialized, CompileOptions options) override {
211     return Unimplemented("DeserializeExecutable not implemented on %s",
212                          platform_name());
213   }
214 
215   StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;
216 
217   // Creates a buffer on the device without initializing or copying any data.
218   // An optional `definition_event` may be speficied that can be used to
219   // ensure the buffer isn't referenced until some external mechanism has
220   // initialized the data.
221   StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
222       const Shape& shape, PjRtDevice* device) override;
223   StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
224       const Shape& shape, PjRtDevice* device,
225       std::shared_ptr<BufferSequencingEvent> definition_event);
226 
227   StatusOr<std::unique_ptr<PjRtClient::AsyncBufferTransferManager>>
CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,PjRtDevice * device)228   CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
229                                 PjRtDevice* device) override {
230     return Unimplemented("Async transfer to buffers not implemented");
231   };
232 
233   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
234       const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
235       std::optional<absl::Span<int64_t const>> byte_strides,
236       HostBufferSemantics host_buffer_semantics,
237       std::function<void()> on_done_with_host_buffer,
238       PjRtDevice* device) override;
239 
240   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
241       const LiteralSlice& literal, PjRtDevice* device) override;
242 
243   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
244   MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
245                               PjRtDevice* device,
246                               PjRtCrossHostRecvNotifier notifier) override;
247 
248   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
249   MakeCrossHostReceiveBuffersForGather(
250       absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
251       PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override;
252 
253   StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
254       void* device_ptr, const Shape& shape, PjRtDevice* device,
255       std::function<void()> on_delete_callback) override;
256 
CreateChannelHandle()257   StatusOr<ChannelHandle> CreateChannelHandle() override {
258     return client()->CreateChannelHandle();
259   }
CreateDeviceToHostChannelHandle()260   StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
261     return client()->CreateDeviceToHostChannelHandle();
262   }
CreateHostToDeviceChannelHandle()263   StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
264     return client()->CreateHostToDeviceChannelHandle();
265   }
266 
267   // TODO(zhangqiaorjc): Experimental. Will be removed.
Defragment()268   Status Defragment() override {
269     return Unimplemented("Defragment not implemented");
270   }
271 
device_state(int device_ordinal)272   LocalDeviceState& device_state(int device_ordinal) const {
273     return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
274                 LookupAddressableDevice(device_ordinal).ValueOrDie())
275                 ->local_device_state();
276   }
client()277   LocalClient* client() const { return client_; }
allocator()278   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
host_memory_allocator()279   tensorflow::Allocator* host_memory_allocator() const {
280     return host_memory_allocator_.get();
281   }
should_stage_host_to_device_transfers()282   bool should_stage_host_to_device_transfers() const {
283     return should_stage_host_to_device_transfers_;
284   }
285 
gpu_run_options()286   gpu::GpuExecutableRunOptions* gpu_run_options() const {
287     return gpu_run_options_.get();
288   }
289 
thread_pool()290   tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }
291 
292  protected:
293   friend class PjRtStreamExecutorBuffer;
294 
EnqueueCrossHostReceive(absl::Span<const std::unique_ptr<PjRtBuffer>> buffers,std::shared_ptr<BufferSequencingEvent> definition_event,PjRtCrossHostRecvNotifier notifier,std::optional<std::vector<GatherDetails>> gather_details)295   virtual Status EnqueueCrossHostReceive(
296       absl::Span<const std::unique_ptr<PjRtBuffer>> buffers,
297       std::shared_ptr<BufferSequencingEvent> definition_event,
298       PjRtCrossHostRecvNotifier notifier,
299       std::optional<std::vector<GatherDetails>> gather_details) const {
300     return Unimplemented("Cross host receives not implemented.");
301   }
302 
CopyToRemoteDevice(PjRtBuffer * buffer,absl::string_view serialized_descriptor,PjRtBuffer::RemoteSendCallback on_done)303   virtual void CopyToRemoteDevice(
304       PjRtBuffer* buffer, absl::string_view serialized_descriptor,
305       PjRtBuffer::RemoteSendCallback on_done) const {
306     on_done(Unimplemented("Cross host sends not implemented."),
307             /*sends_were_enqueued=*/false);
308   }
309 
CopyToRemoteDeviceScattered(PjRtBuffer * buffer,absl::Span<const std::pair<std::string,PjRtBuffer::RemoteSendCallback>> serialized_descriptors_and_callbacks,const PjRtBuffer::ScatterDetails & scatter_details)310   virtual void CopyToRemoteDeviceScattered(
311       PjRtBuffer* buffer,
312       absl::Span<const std::pair<std::string, PjRtBuffer::RemoteSendCallback>>
313           serialized_descriptors_and_callbacks,
314       const PjRtBuffer::ScatterDetails& scatter_details) const {
315     for (const auto& d_and_cb : serialized_descriptors_and_callbacks) {
316       d_and_cb.second(
317           Unimplemented("Scattered cross host sends not implemented."),
318           /*sends_were_enqueued=*/false);
319     }
320   }
321 
CopyRawSubBufferToHost(PjRtBuffer * buffer,void * dst,int64_t offset,int64_t transfer_size)322   virtual PjRtFuture<Status> CopyRawSubBufferToHost(PjRtBuffer* buffer,
323                                                     void* dst, int64_t offset,
324                                                     int64_t transfer_size) {
325     return PjRtFuture<Status>(
326         Unimplemented("Raw copies to host not implemented."));
327   }
328 
329   // Helper function for creating PjRtStreamExecutorExecutables. Modifies
330   // `options` in-place.
331   struct ExecutableExtras {
332     std::shared_ptr<DeviceAssignment> device_assignment;
333     std::vector<PjRtLoadedExecutable::LogicalDeviceIds>
334         addressable_device_logical_ids;
335     std::vector<PjRtDevice*> addressable_devices;
336   };
337   StatusOr<ExecutableExtras> GetExecutableExtras(CompileOptions* options);
338 
339   const PjRtPlatformId platform_id_;
340   const std::string platform_name_;
341   LocalClient* client_;
342 
343   // Allocator to be used for staging memory transfers to devices.
344   std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
345 
346   // Device memory allocator. If owned, the allocator must outlive the devices,
347   // because it is the device destructor that waits for any outstanding work to
348   // complete.
349   se::DeviceMemoryAllocator* allocator_;
350   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
351 
352   // Includes all devices, including non-local devices on multi-host platforms.
353   std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
354   // Pointers to `owned_devices_`.
355   std::vector<PjRtDevice*> devices_;
356   // Maps Device::id() to the corresponding Device. Includes all devices.
357   std::map<int, PjRtDevice*> id_to_device_;
358   // Local devices indexed by local device ordinal.
359   std::vector<PjRtDevice*> addressable_devices_;
360   int process_index_;
361 
362   // Should we always prefer to stage host-to-device transfers via memory
363   // allocated on host_memory_allocator_? True only on GPU, where we prefer to
364   // transfer via pinned memory.
365   bool should_stage_host_to_device_transfers_;
366 
367   std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
368 
369   tensorflow::thread::ThreadPool thread_pool_;
370 
371   absl::Mutex transpose_mu_;
372   TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_);
373 };
374 
375 // Converts a 2D set of Device objects indexed by [replica][partition] into an
376 // xla::DeviceAssignment.
377 StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
378     absl::Span<const std::vector<PjRtDevice*>> devices);
379 
380 class PjRtStreamExecutorBuffer : public PjRtBuffer {
381  public:
382   // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold
383   // may not outlive its parent PjRtStreamExecutorBuffer.
384   //
385   // There are three types of hold, as follows:
386   //
387   // 1) Usage hold: a transient hold while an operation using the buffer is
388   //    being enqueued onto a stream.
389   // A client acquires a usage hold by calling
390   // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
391   // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the
392   // hold should be released using a call to ConvertUsageHold. If the ScopedHold
393   // is deleted without ConvertUsageHold being called, e.g., on error, the hold
394   // is dropped. It is legal to drop a usage hold instead of calling
395   // ConvertUsageHold, even if the buffer was successfully enqueued, as long as
396   // the client ensures that all necessary synchronization has been done.
397   //
398   // 2) External hold: a potentially long-lived hold while the buffer is being
399   //    shared by an external framework, e.g., NumPy.
400   // A client acquires an external hold by calling
401   // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
402   // wrapper GetBufferWithExternalReference and releases it by deleting the
403   // ScopedHold. The external framework should not modify the underlying buffer
404   // unless it is confident via its own synchronization that modifications do
405   // not race with reads from the PjRtStreamExecutorBuffer.
406   //
407   // 3) Donation hold: a transient hold while an execution that donates the
408   //    buffer is being enqueued onto the compute stream.
409   // A client acquires a donation hold by calling
410   // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
411   // completes successfully the hold should be released using a call to
412   // ConfirmDonation after which the buffer is invalid. If the ScopedHold is
413   // deleted without ConfirmDonation being called, e.g., on error, the hold is
414   // dropped and the buffer remains valid. If the buffer is successfully
415   // enqueued the client *must* call ConfirmDonation.
416   //
417   // Donation holds behave like exclusive write locks: when a donation hold
418   // has been acquired, any attempt to acquire another hold of any type will
419   // block until the donation hold is dropped or confirmed. Acquiring a donation
420   // hold will fail with an error if there is any outstanding external hold, and
421   // will block if there are any outstanding usage holds until those holds are
422   // dropped or converted.
423   //
424   // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
425   // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
426   // block until all usage and donation holds are either deleted or
427   // converted/confirmed.
428   class ScopedHold {
429    public:
430     enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
431     // Use a State enum instead of encoding the state in an error Status to
432     // avoid creating Status values in non-error cases. Creating a Status
433     // entails several allocations and can add O(us) to every use of a hold.
434     enum State {
435       kUninitialized = 0,
436       kValid,
437       kMoved,
438       kConverted,
439       kReleased,
440       kDonated,
441       kError
442     };
443 
444     ~ScopedHold();
445     ScopedHold(ScopedHold&& other);
446     ScopedHold(const ScopedHold&) = delete;
447     ScopedHold& operator=(const ScopedHold&) = delete;
448 
type()449     Type type() const { return type_; }
450 
status()451     Status status() const {
452       // Lazily create Status values only when they are requested.
453       switch (state_) {
454         case kUninitialized:
455           return InvalidArgument("Buffer has not been initialized");
456         case kValid:
457           return OkStatus();
458         case kMoved:
459           return InvalidArgument("Buffer has been moved.");
460         case kConverted:
461           return InvalidArgument("Buffer has been converted");
462         case kReleased:
463           return InvalidArgument("Buffer has been released");
464         case kDonated:
465           return InvalidArgument("Buffer has been donated");
466         case kError:
467           return status_;
468         default:
469           CHECK(false) << "Unexpected state value " << state_;
470       }
471     }
ok()472     bool ok() const { return state_ == kValid; }
473 
474     // Access to the underlying device buffer storage. Requires this->ok().
buffer()475     const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
476       CHECK_EQ(state_, kValid);
477       CHECK_NE(buffer_, nullptr);
478       return buffer_;
479     }
480     TrackedDeviceBuffer* operator->() const { return buffer().get(); }
481     const TrackedDeviceBuffer& operator*() const { return *buffer(); }
482 
483     // Converts the hold into a usage event. Only valid for holds of type
484     // kUsage.
485     //
486     //   usage_stream:   the stream that the buffer was used on.
487     //   event:          an event that has been recorded on usage_stream after
488     //                   the buffer was used.
489     //   reference_held: true if and only if the caller has caused a
490     //                   reference to this->buffer() to stay live until after
491     //                   the host is sure that the usage (transfer or execution)
492     //                   has completed.
493     void ConvertUsageHold(se::Stream* usage_stream,
494                           std::shared_ptr<BufferSequencingEvent> event,
495                           bool reference_held);
496 
497     // Confirms that the buffer was successfully donated to an execution.
498     // Only valid for holds of type kDonation. Causes the buffer to become
499     // invalid.
500     void ConfirmDonation();
501 
502     // Adds the held device buffers in order to 'iterator'. Used to add the
503     // buffers to an ExecutionInput. We require but do not verify that
504     // 'iterator' when passed in is pointing to a sub-tuple of the
505     // ExecutionInput whose on_device_shape matches that of the
506     // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
507     // out of bounds. Donates the device buffers if the hold type is kDonation,
508     // otherwise retains ownership of the device buffers.
509     void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
510                     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
511                     ExecutionInput* execution_input,
512                     se::DeviceMemoryAllocator* allocator) const;
513 
514    private:
515     friend class PjRtStreamExecutorBuffer;
516     friend class PjRtStreamExecutorClient;
517 
518     // Helper struct that makes it possible to move a ScopedHold through a
519     // closure.
520     using ForClosure = std::tuple<PjRtStreamExecutorBuffer*, Type, State,
521                                   Status, std::shared_ptr<TrackedDeviceBuffer>>;
522 
ScopedHold(PjRtStreamExecutorBuffer * parent,Type type)523     ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
524         : parent_(parent), type_(type), state_(kUninitialized) {}
ScopedHold(const ForClosure & closure_helper)525     explicit ScopedHold(const ForClosure& closure_helper)
526         : parent_(std::get<0>(closure_helper)),
527           type_(std::get<1>(closure_helper)),
528           state_(std::get<2>(closure_helper)),
529           status_(std::get<3>(closure_helper)),
530           buffer_(std::get<4>(closure_helper)) {
531       // Check the buffer is not in an error state.
532       CHECK(status_.ok() && buffer_ != nullptr);
533     }
534 
535     // Sets buffer state.
SetState(State state)536     void SetState(State state) { state_ = state; }
537 
538     // Sets buffer_ and status_. Called by parent_ to initialize the hold.
539     void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
540     // Releases the contents of *this, so *this can subsequently be
541     // deleted without releasing the parent's hold. Should be passed to the
542     // appropriate constructor of another ScopedHold, e.g., when a hold must be
543     // passed through a closure that is incompatible with std::move.
544     ForClosure ToClosure();
545 
546     PjRtStreamExecutorBuffer* const parent_;
547     const Type type_;
548 
549     // There is an invariant that if ok() then
550     // buffer_.ValueOrDie() != nullptr.
551     State state_;
552     Status status_;
553     std::shared_ptr<TrackedDeviceBuffer> buffer_;
554   };
555 
556   PjRtStreamExecutorBuffer(Shape on_device_shape,
557                            std::shared_ptr<TrackedDeviceBuffer> device_buffer,
558                            PjRtClient* client, PjRtDevice* device);
559   ~PjRtStreamExecutorBuffer() override;
560 
561   PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
562   PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
563   PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
564   PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
565 
on_device_shape()566   const Shape& on_device_shape() const override { return on_device_shape_; }
567   StatusOr<Shape> logical_on_device_shape() override;
device()568   PjRtStreamExecutorDevice* device() const override { return device_; }
platform_id()569   PjRtPlatformId platform_id() const { return client_->platform_id(); }
platform_name()570   absl::string_view platform_name() const { return client_->platform_name(); }
client()571   PjRtStreamExecutorClient* client() const override { return client_; }
IsEmptyTuple()572   bool IsEmptyTuple() const {
573     return on_device_shape_.IsTuple() &&
574            on_device_shape_.tuple_shapes_size() == 0;
575   }
576 
577   StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
578       override;
579 
580   StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
581       bool wait_for_operations_to_complete) override;
582 
583   using PjRtBuffer::ToLiteralSync;
584   PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override;
585 
586   StatusOr<size_t> GetOnDeviceSizeInBytes() const override;
587 
588   PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset,
589                                    int64_t transfer_size) override;
590 
591   // Drops the buffer's reference to its associated device memory, leaving the
592   // buffer in an invalid state. The memory will be freed lazily when all async
593   // operations using the buffer have completed, according to the allocation
594   // semantics of the underlying platform. Delete may briefly block if another
595   // thread is in the process of enqueuing an operation on this buffer, but it
596   // will never block for a stream operation to complete. If an external
597   // framework holds a reference to the TrackedDeviceBuffer via
598   // GetBufferWithExternalReference, the memory will not be freed until the
599   // external framework drops the reference.
600   void Delete() override;
601 
602   bool IsDeleted() override;
603 
604   // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
605   // PjRtBuffer retains ownership of the device buffers.
606   StatusOr<ShapedBuffer> AsShapedBuffer() const;
607 
608   // Returns a hold on the TrackedDeviceBuffer holding the device
609   // buffers. See comment on ScopedHold.
610   ScopedHold GetBufferWithHold(ScopedHold::Type type);
GetBufferWithUsageHold()611   ScopedHold GetBufferWithUsageHold() {
612     return GetBufferWithHold(ScopedHold::kUsage);
613   }
GetBufferWithExternalReference()614   ScopedHold GetBufferWithExternalReference() {
615     return GetBufferWithHold(ScopedHold::kExternalReference);
616   }
617 
618   StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
619       PjRtDevice* dst_device) override;
620 
621   void CopyToRemoteDevice(absl::string_view serialized_descriptor,
622                           RemoteSendCallback on_done) override;
623 
624   void CopyToRemoteDeviceScattered(
625       absl::Span<const std::pair<std::string, RemoteSendCallback>>
626           serialized_descriptors_and_callbacks,
627       const ScatterDetails& scatter_details) override;
628 
629   PjRtFuture<Status> GetReadyFuture() override;
630 
631   bool IsOnCpu() const override;
632 
633   // Similar to Delete, drops the buffer's reference to its associated device
634   // memory, leaving the buffer in an invalid state, but returns the
635   // TrackedDeviceBuffer rather than freeing the device memory, so that another
636   // framework can take ownership of it. The buffer returned from Release may
637   // be safely dropped at any time even if it still has pending async
638   // operations. The client should call GetReadyFuture()->Await() before calling
639   // Release with wait_for_operations_to_complete=false, to ensure that the host
640   // has synchronized past any outstanding write operations to the buffer. If
641   // wait_for_operations_to_complete=true the host will block until any
642   // potentially outstanding asynchronous operations have completed before
643   // returning, in which case it is safe to read or mutate the returned buffer.
644   // If the buffer was shared via an external reference it is the client's
645   // responsibility that accesses via that reference do not interfere with
646   // accesses via the buffer returned from Release.
647   StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
648       bool wait_for_operations_to_complete);
649 
650  private:
651   friend class PjRtClient;
652 
653   // Blocks in mu_.Await until there are no more usage holds.
654   void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
655 
656   // Blocks in mu_.Await until there is no donation hold.
657   void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
658 
659   // Adds a hold of 'type' and returns device_buffer_. Returns an error if
660   // device_buffer_ is null, or if a donation hold was requested when there is
661   // an outstanding external hold.
662   // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds()
663   // must be called first.)
664   StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
665       ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
666 
667   // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
668   // Initializes hold with an error if device_buffer_ is null, or if a donation
669   // hold was requested when there is an outstanding external hold.
670   // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds()
671   // must be called first.)
672   void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
673 
674   // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
675   // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
676   // device_buffer_ was successfully enqueued on a stream.
677   void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
678                         std::shared_ptr<BufferSequencingEvent> event,
679                         bool reference_held);
680 
681   // Drops a donation hold and makes *this invalid for further use. Does a
682   // sanity check that buffer==device_buffer_. Called after device_buffer_ was
683   // successfully donated to an execution.
684   void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
685 
686   // Drops a hold without taking any other action. Does a sanity check that
687   // buffer==device_buffer_ or device_buffer_==nullptr.
688   void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
689 
690   StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
691                      std::shared_ptr<BufferSequencingEvent>>>
692   CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
693                      LocalDeviceState* transfer_local_device,
694                      se::Stream* transfer_stream,
695                      std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
696 
697   PjRtStreamExecutorClient* const client_;
698   const Shape on_device_shape_;
699   PjRtStreamExecutorDevice* const device_;
700 
701   mutable absl::Mutex mu_;
702   std::shared_ptr<TrackedDeviceBuffer> device_buffer_ ABSL_GUARDED_BY(mu_);
703   // Count of holds on the buffer.
704   std::array<int, ScopedHold::Type::kMaxValue> holds_ ABSL_GUARDED_BY(mu_);
705   PjRtFuture<Status>::Promise definition_promise_ ABSL_GUARDED_BY(mu_);
706 };
707 
708 // Wraps one or more XLA LocalExecutables (one per partition, as specified by
709 // the build options).
710 class PjRtStreamExecutorExecutable : public PjRtLoadedExecutable {
711  public:
712   PjRtStreamExecutorExecutable(
713       std::vector<std::unique_ptr<LocalExecutable>> executables,
714       bool parameter_is_tupled_arguments,
715       std::shared_ptr<DeviceAssignment> device_assignment,
716       std::vector<LogicalDeviceIds> addressable_device_logical_ids,
717       std::vector<PjRtDevice*> addressable_devices,
718       PjRtStreamExecutorClient* client);
719 
720   ~PjRtStreamExecutorExecutable() override = default;
721 
client()722   PjRtStreamExecutorClient* client() const override { return client_; }
723 
724   absl::string_view name() const override;
725 
num_replicas()726   int num_replicas() const override {
727     return executables_[0]->build_options().num_replicas();
728   }
729 
num_partitions()730   int num_partitions() const override {
731     return executables_[0]->build_options().num_partitions();
732   }
733 
SizeOfGeneratedCodeInBytes()734   int64_t SizeOfGeneratedCodeInBytes() const override {
735     int64_t size = 0;
736     for (auto& executable : executables_) {
737       size += executable->executable()->SizeOfGeneratedCodeInBytes();
738     }
739     return size;
740   }
741 
device_assignment()742   const DeviceAssignment& device_assignment() const override {
743     return *device_assignment_;
744   }
745 
addressable_device_logical_ids()746   absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
747       const override {
748     return addressable_device_logical_ids_;
749   }
750 
addressable_devices()751   absl::Span<PjRtDevice* const> addressable_devices() const override {
752     return addressable_devices_;
753   }
754 
755   // Return an HloModule per partition.
756   StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
757       const override;
758 
759   using PjRtLoadedExecutable::Execute;
760   StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
761       absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
762       const ExecuteOptions& options,
763       std::optional<std::vector<PjRtFuture<Status>>>& returned_futures)
764       override;
765 
766   using PjRtLoadedExecutable::ExecuteSharded;
767   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
768       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
769       const ExecuteOptions& options,
770       std::optional<PjRtFuture<Status>>& returned_future,
771       bool fill_future) override;
772 
773   using PjRtLoadedExecutable::ExecutePortable;
774   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
775       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
776       const ExecuteOptions& options,
777       std::optional<PjRtFuture<Status>>& returned_future,
778       bool fill_future) override;
779 
Delete()780   void Delete() override { executables_.clear(); }
781 
IsDeleted()782   bool IsDeleted() override { return executables_.empty(); }
783 
IsReturnedFutureSupported()784   bool IsReturnedFutureSupported() const override { return true; }
785 
executables()786   absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
787     return executables_;
788   }
789 
790  protected:
parameter_is_tupled_arguments()791   bool parameter_is_tupled_arguments() const {
792     return parameter_is_tupled_arguments_;
793   }
794 
795  private:
796   friend class PjRtStreamExecutorClient;
797   friend class PjRtTpuClient;
798   friend class InternalPjRtTpuClient;
799   // Initializes information about which arguments to which executables must be
800   // donated due to aliases that were specified by the computation.
801   Status SetUpDonation(bool tuple_inputs);
802 
803   // Returns a sorted list of the parameters that must be donated. Derived
804   // classes may use custom logic.
805   virtual absl::Span<int const> ParametersThatMustBeDonated(
806       int executable_idx) const;
807 
808   virtual StatusOr<std::vector<ExecutionInput>>
809   MakeExecutionInputsAndWaitForEvents(
810       int device_ordinal, const ExecuteOptions& options,
811       absl::Span<const Shape> executable_parameter_shapes,
812       absl::Span<PjRtBuffer* const> argument_handles,
813       absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
814       absl::flat_hash_set<BufferSequencingEvent*>& events) const;
815 
816   StatusOr<ScopedShapedBuffer> EnqueueExecution(
817       absl::Span<PjRtBuffer* const> argument_handles, int replica,
818       int partition, int executable_idx, const RunId& run_id,
819       const ExecuteOptions& options, PjRtDevice* device,
820       std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
821       std::shared_ptr<DeviceAssignment> device_assignment,
822       std::vector<std::function<void()>>& compute_callbacks) const;
823 
824   virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
825       int device_ordinal, const ExecuteOptions& options,
826       ScopedShapedBuffer result_buffer,
827       std::shared_ptr<BufferSequencingEvent> definition_event,
828       PjRtDevice* device, std::vector<std::function<void()>>& compute_callbacks,
829       std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
830       const;
831 
832   StatusOr<Result> ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
833                                  int replica, int partition,
834                                  const RunId& run_id,
835                                  const ExecuteOptions& options,
836                                  bool fill_future,
837                                  PjRtDevice* device = nullptr) const;
838 
839   // Create shared pointers so we can free them after the execution: with
840   // asynchronous execution, the process being executed can outlive the
841   // executable itself.
842   PjRtStreamExecutorClient* const client_;
843   // One executable per partition.
844   std::vector<std::shared_ptr<LocalExecutable>> executables_;
845   // On device shapes of the executable parameters.
846   std::vector<std::vector<Shape>> on_device_executable_parameter_shapes_;
847   // Per-executable sorted vector of parameters that have any aliased buffers
848   // and thus must be donated when executing the computation.
849   std::vector<std::vector<int>> parameters_that_must_be_donated_;
850   std::shared_ptr<DeviceAssignment> device_assignment_;
851 
852   // True if the executables were compiled expecting arguments in a single
853   // tuple.
854   const bool parameter_is_tupled_arguments_;
855 
856   // The replica and partition indices of device_assignment_ to be run by this
857   // client. On single-host platforms without partitioning, this is all replicas
858   // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
859   // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
860   // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
861   std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
862 
863   // addressable_devices_[i] is the Device to which
864   // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
865   // unique_ptrs to play well with the Python bindings (see xla.cc).
866   std::vector<PjRtDevice*> addressable_devices_;
867 };
868 
869 }  // namespace xla
870 
871 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
872