xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/pjrt_client.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
18 
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_executable.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
47 
48 // API notes:
49 // PjRt stands for "Pretty much Just another RunTime".
50 
51 namespace xla {
52 
// Opaque 64-bit identifier of a PjRt platform. The built-in CPU/GPU/TPU ids
// (CpuId() etc.) are fingerprints of the corresponding platform names.
using PjRtPlatformId = uint64_t;
54 
// Canonical platform name for CPU backends. The pointer refers to a
// function-local static array, so it stays the same across calls and must not
// be freed.
inline const char* CpuName() {
  static constexpr char kName[] = "cpu";
  return &kName[0];
}
// Canonical platform name for GPU backends. The pointer refers to a
// function-local static array, so it stays the same across calls and must not
// be freed.
inline const char* GpuName() {
  static constexpr char kName[] = "gpu";
  return &kName[0];
}
// Canonical platform name for TPU backends. The pointer refers to a
// function-local static array, so it stays the same across calls and must not
// be freed.
inline const char* TpuName() {
  static constexpr char kName[] = "tpu";
  return &kName[0];
}
CpuId()67 inline PjRtPlatformId CpuId() {
68   static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(CpuName());
69   return kCpuId;
70 }
GpuId()71 inline PjRtPlatformId GpuId() {
72   static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(GpuName());
73   return kGpuId;
74 }
TpuId()75 inline PjRtPlatformId TpuId() {
76   static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(TpuName());
77   return kTpuId;
78 }
79 
80 enum PjRtRuntimeType { kStreamExecutor, kTfrt };
PjRtRuntimeTypeString(PjRtRuntimeType type)81 inline constexpr absl::string_view PjRtRuntimeTypeString(PjRtRuntimeType type) {
82   switch (type) {
83     case kStreamExecutor:
84       return "stream_executor";
85     case kTfrt:
86       return "tfrt";
87   }
88 }
89 
class PjRtClient;

// Value of a vendor-specific device attribute (see PjRtDevice::Attributes()):
// a string, a scalar integer, or a list of integers.
using PjRtDeviceAttribute =
    std::variant<std::string, int64_t, std::vector<int64_t>>;
94 
95 class PjRtDevice {
96  public:
~PjRtDevice()97   virtual ~PjRtDevice() {}
98 
99   // Return the client that owns this device.
100   virtual PjRtClient* client() const = 0;
101 
102   // Whether client can issue command to this device.
103   virtual bool IsAddressable() const = 0;
104 
105   // The ID of this device. IDs are unique among devices of this type
106   // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
107   // hosts' devices.  This is the ID that should be used in a DeviceAssignment.
108   virtual int id() const = 0;
109 
110   // The index of the process that this device belongs to, i.e. is addressable
111   // from. This is not always identical to PjRtClient::process_index() in a
112   // multi-process setting, where each client can see devices from all
113   // processes, but only a subset of them are addressable and have the same
114   // process_index as the client.
115   virtual int process_index() const = 0;
116 
117   // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
118   // which GPU when interacting with non-JAX code. In general, not guaranteed to
119   // be dense, and -1 if undefined.
120   virtual int local_hardware_id() const = 0;
121 
122   // A vendor-dependent string that uniquely identifies the kind of device,
123   // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
124   // compatible compilation.
125   virtual absl::string_view device_kind() const = 0;
126 
127   // Debug string suitable for logging when errors occur. Should be verbose
128   // enough to describe the current device unambiguously.
129   virtual absl::string_view DebugString() const = 0;
130 
131   // Debug string suitable for reading by end users, should be reasonably terse,
132   // for example: "CpuDevice(id=0)".
133   virtual absl::string_view ToString() const = 0;
134 
135   // Returns a scoped event that the caller uses to tell the PjRtClient that
136   // there is asynchronous work happening that depends on activity on the
137   // PjRtDevice. See comment on class definition in pjrt_future.h.
138   //
139   // Only some PjRtDevice implementations support ScopedAsyncTrackingEvent, and
140   // those that do not will return nullptr.
141   virtual std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
142       absl::string_view description) const = 0;
143 
144   // Transfer the given literal to the infeed queue.
145   virtual Status TransferToInfeed(const LiteralSlice& literal) = 0;
146 
147   // Transfer and return a value of the given shape from the outfeed queue.
148   virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0;
149 
150   // Returns vendor specific attributes about the device. For example the model
151   // number of a GPU, or the mesh coordinates of a TPU device. The returned
152   // reference will remain valid for the lifetime of the PjRtDevice.
153   virtual const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
154   Attributes() const = 0;
155 };
156 
157 // Forward declaration.
158 class PjRtBuffer;
159 
160 // Helper struct for cross host transfers, returned by the callback from a call
161 // to PjRtBuffer::MakeCrossHostReceiveBuffers or
162 // PjRtBuffer::MakeCrossHostReceiveBuffersForGather.
163 struct PjRtCrossHostRecvDescriptors {
164   // There is one serialized_descriptor per sub-buffer being gathered (i.e. a
165   // single descriptor if the buffer is returned from a call to
166   // MakeCrossHostReceiveBuffers). The descriptor should be transmitted to the
167   // sender(s) and passed to a call to src_buffer->CopyToRemoteDevice.
168   absl::InlinedVector<std::string, 1> serialized_descriptors;
169 };
170 // Function that the client should call at the receiver if it needs to cancel a
171 // cross-host send, for example because the buffer that the remote host wanted
172 // to send is not available. The serialized descriptor should match one of the
173 // descriptors returned in a PjRtCrossHostRecvDescriptors. on_canceled will be
174 // called once cancellation is complete and indicates whether cancellation was
175 // successful or not.
176 //
177 // For each serialized_descriptor provided in a PjRtCrossHostRecvDescriptors,
178 // *either* the sending host must successfully complete a CopyToRemoteDevice
179 // for that descriptor, *or* the receiving host must cancel. If there is a
180 // duplicate (e.g., both send and cancel) then the system will be left in an
181 // undefined state. If there is no send or cancellation then the system will
182 // hang indefinitely.
183 using PjRtCrossHostSendCancelNotifier =
184     std::function<void(absl::string_view serialized_descriptor, Status reason,
185                        std::function<void(Status)> on_canceled)>;
186 // State asynchronously returned by MakeCrossHostReceiveBuffers. "descriptors"
187 // will match the returned PjRtBuffer objects 1:1. Specifically, each PjRtBuffer
188 // returned by MakeCrossHostReceiveBuffers will have one
189 // PjRtCrossHostRecvDescriptors object containing it descriptor(s).
190 struct PjRtCrossHostRecvState {
191   std::vector<PjRtCrossHostRecvDescriptors> descriptors;
192   PjRtCrossHostSendCancelNotifier cancel_notifier;
193 };
194 using PjRtCrossHostRecvNotifier =
195     std::function<void(StatusOr<PjRtCrossHostRecvState>)>;
196 
197 // Provides configuration for implementations that support compile and execute
198 // spanning multiple slices. A slice is a set of devices connected by dedicated
199 // high speed interconnect. Connectivity between slices is typically over data
200 // center networks. Concrete implementations of MultiSliceConfig contain
201 // environment specific information to enable communication between devices on
202 // different slices. Passed as options during compile and execute.
203 // Implementations that do not support this are allowed to pass nullptr.
204 class MultiSliceConfig {
205  public:
206   virtual ~MultiSliceConfig();
207 
208   // Returns the total number of slices.
209   virtual int32_t NumSlices() const = 0;
210 
211   // Returns the SliceID at this host - an integer in [0, NumSlices)
212   virtual int32_t SliceId() const = 0;
213 
214   // Returns the number of devices on each slice indexed by SliceId.
215   virtual absl::flat_hash_map<int32_t, int32_t> NumDevicesPerSlice() const = 0;
216 };
217 
218 struct CompileOptions {
219   // The layouts of the arguments that the computation should expect.
220   std::optional<std::vector<Shape>> argument_layouts;
221 
222   // If true, the supplied computation expects its arguments to be wrapped in a
223   // tuple and passed as a single parameter.
224   bool parameter_is_tupled_arguments = false;
225 
226   // XLA's compilation time options.
227   ExecutableBuildOptions executable_build_options;
228 
229   // If true, the executable can be run on any device. May only be true if
230   // !executable_build_options.has_device_assignment(), so only applies to
231   // single-device executables. Beware: on GPUs, sometimes an executable
232   // compiled for one device doesn't run on another.
233   bool compile_portable_executable = false;
234 
235   // XLA compilation profile version.
236   int64_t profile_version = 0;
237 
238   // Set multi_slice_config to trigger compilation for DCN connected multi
239   // slice operation.
240   const MultiSliceConfig* multi_slice_config = nullptr;
241 
242   // Serialize the CompileOptions into a CompileOptionsProto.
243   StatusOr<CompileOptionsProto> ToProto() const;
244 };
245 
246 StatusOr<CompileOptions> CompileOptionsFromProto(
247     const CompileOptionsProto& input);
248 
// A sized chunk of host data. The host data can be either in host layout or in
// device layout, and it can be one part of the entire buffer. The PjRt
// implementations can customize how the memory is allocated and deallocated.
//
// PjRtChunk is move-only. While it owns a non-null `data()` it hands that
// pointer to the deleter exactly once: on destruction, or when it is
// move-assigned over. A moved-from chunk is empty: data() == nullptr and
// size() == 0.
class PjRtChunk {
 public:
  // Allocates an uninitialized PjRtChunk of `size` bytes with malloc; the
  // bytes are released with free.
  // NOTE(review): allocation failure is not reported; in that case the chunk
  // has data() == nullptr but size() == `size`.
  static PjRtChunk AllocateDefault(size_t size) {
    return PjRtChunk(std::malloc(size), size,
                     [](void* ptr) { std::free(ptr); });
  }

  PjRtChunk() = default;

  // Takes ownership of `size` bytes at `data`. `deleter` is invoked on `data`
  // when the chunk releases it; it is never invoked with nullptr.
  PjRtChunk(void* data, size_t size, std::function<void(void*)> deleter)
      : data_(static_cast<uint8_t*>(data)),
        size_(size),
        deleter_(std::move(deleter)) {}

  ~PjRtChunk() { Release(); }

  // Steals `other`'s buffer and deleter, leaving `other` empty.
  PjRtChunk(PjRtChunk&& other) noexcept
      : data_(std::exchange(other.data_, nullptr)),
        size_(std::exchange(other.size_, 0)),
        deleter_(std::move(other.deleter_)) {}

  // Releases the currently owned buffer (if any), then steals `other`'s.
  PjRtChunk& operator=(PjRtChunk&& other) noexcept {
    // Without this guard a self-move would free our own buffer and leave the
    // chunk empty.
    if (this != &other) {
      Release();
      data_ = std::exchange(other.data_, nullptr);
      size_ = std::exchange(other.size_, 0);
      deleter_ = std::move(other.deleter_);
    }
    return *this;
  }

  PjRtChunk(const PjRtChunk&) = delete;
  PjRtChunk& operator=(const PjRtChunk&) = delete;

  uint8_t* data() { return data_; }
  const uint8_t* data() const { return data_; }
  int64_t size() const { return size_; }

 private:
  // Passes the owned buffer, if any, to the deleter. Leaves data_ dangling;
  // callers must overwrite it or be in the destructor.
  void Release() {
    if (data_) {
      deleter_(data_);
    }
  }

  // The ownership of the bytes pointed to by `data_` is controlled by the
  // `deleter_`.
  uint8_t* data_ = nullptr;
  size_t size_ = 0;
  std::function<void(void*)> deleter_;
};
302 
303 // A stream of Chunks from the host to the device. Once the stream enters
304 // Complete state it never changes state again.
305 //
306 // This class is thread-safe.
307 class CopyToDeviceStream {
308  public:
CopyToDeviceStream(int64_t total_bytes,int64_t granule_bytes)309   explicit CopyToDeviceStream(int64_t total_bytes, int64_t granule_bytes)
310       : total_bytes_(total_bytes), granule_bytes_(granule_bytes) {}
311 
312   // Emplaces a new Chunk of data to copy to the device. Returns a non-OK status
313   // if the Chunk's size causes the amount of transferred data to exceed
314   // total_bytes() or if the stream is already complete.
315   //
316   // The size of the chunk must be a multiple of granule_bytes().
317   // TODO(jmolloy): Enforce the granule size.
318   Status AddChunk(PjRtChunk chunk);
319 
320   // Returns the total amount of data the stream expects to be transferred.
total_bytes()321   int64_t total_bytes() const { return total_bytes_; }
322 
323   // Returns the granule size in bytes. The size of the chunk added to this
324   // stream must be a multiple of this number.
granule_size_in_bytes()325   int64_t granule_size_in_bytes() const { return granule_bytes_; }
326 
327   // Returns the amount of data the stream currently has either transferred or
328   // has buffered to transfer.
current_bytes()329   int64_t current_bytes() const {
330     absl::MutexLock lock(&mu_);
331     return current_bytes_;
332   }
333 
334   // Returns true if the stream is complete; all expected bytes have been
335   // transferred or are buffered to transfer.
IsComplete()336   bool IsComplete() const {
337     absl::MutexLock lock(&mu_);
338     return current_bytes_ == total_bytes_;
339   }
340 
341   // Returns true if the stream is empty; no data has been queued.
empty()342   bool empty() const { return current_bytes() == 0; }
343 
344   // Consumes the next chunk. If no chunks remain, returns nullopt. Blocks
345   // until a chunk is available.
346   std::optional<PjRtChunk> ConsumeNextChunk();
347 
348   // Members are protected to allow subclassing for mocking in tests.
349  protected:
350   int64_t total_bytes_;
351   int64_t granule_bytes_;
352   int64_t current_bytes_ ABSL_GUARDED_BY(mu_) = 0;
353   std::deque<PjRtChunk> buffered_chunks_ ABSL_GUARDED_BY(mu_);
354   mutable absl::Mutex mu_;
355 };
356 
357 class PjRtHostMemoryForDeviceManager {
358  public:
359   virtual ~PjRtHostMemoryForDeviceManager();
360 
361   // Transforms the host memory representations of a shape with the host layout
362   // to the host memory representation of the same shape with the device layout.
363   // `src_shape` and `dst_shape` may only differ in their layouts.
364   virtual StatusOr<PjRtChunk> ToDeviceLayout(const void* src_data,
365                                              size_t src_size,
366                                              const Shape& host_shape,
367                                              const Shape& device_shape) = 0;
368 
369   // Transforms the host memory representations of a shape with the device
370   // layout to the host memory representation of the same shape with the host
371   // layout. `src_shape` and `dst_shape` may only differ in their layouts.
372   virtual Status ToHostLayout(const void* src_data, size_t src_size,
373                               const Shape& src_shape, void* dst_data,
374                               size_t dst_size, const Shape& dst_shape) = 0;
375 };
376 
377 class PjRtLoadedExecutable;
378 
// Encapsulates the state of a Python session with XLA.
380 //
381 // It is the responsibility of the client of this API to keep the PjRtClient
382 // alive as long as any of the other runtime objects are alive.
383 //
384 // A note on the semantics of cross-device copies.
385 //
386 // There are two mechanisms to transfer a buffer from one device to another.
387 // When both devices are on the same host (more specifically, the user program
388 // ends up with pointers to both the source and destination buffers in the same
389 // address space), the caller can use:
390 //   dst_buffer = src_buffer->CopyToDevice(dst_device)
391 //
392 // When the source and destination are on different hosts, but the transfer is
393 // made via native device networking (as opposed to the user program fetching
394 // the buffer and sending it using its own networking code), the caller can
395 // use:
396 //   DstHost: dst_client->MakeCrossHostReceiveBuffers(...)
397 //   DstHost: [...]
398 //   DstHost: gets callback containing PjRtCrossHostRecvDescriptors
399 //   DstHost: sends cross-host recv serialized descriptors to SrcHost
400 //   SrcHost: src_buffer->CopyToRemoteDevice(serialized_descriptors)
401 //
402 // Note that in the cross-host case, the dst_client may call
403 // MakeCrossHostReceiveBuffers before the action that produces src_buffer has
404 // been enqueued at SrcHost.
405 //
406 // On some platforms, device-to-device transfers consume scarce hardware
407 // resources. If dst_client->MakeCrossHostReceiveBuffers immediately claimed
408 // those resources, then there would be a risk of system-wide deadlock, if the
409 // resources claimed by the recv prevented other transfers that are necessary
410 // to generate src_buffer from acquiring enough resources to proceed.
411 //
412 // In order to allow clients to avoid deadlocks such as those in the preceding
413 // paragraph, PjRtClient guarantees progress but not fairness with respect to
414 // the order that cross-device transfers are enqueued on a given host, as
415 // follows:
416 //
417 // The progress guarantee is that a cross-device transfer T on host A will not
418 // claim scarce hardware resources until it is guaranteed that all transfers
419 // enqueued on A before T have already either completed, or been assigned enough
420 // resources to ensure that they can eventually complete.
421 //
422 // The lack of a fairness guarantee means that, if cross-device transfer T1 is
423 // enqueued before transfer T2 at A, then T2 may complete before T1. T1 may be
424 // delayed for an unbounded time waiting for T2 if T2 is large, even though T1
425 // will eventually be able to make progress.
426 class PjRtClient {
427  public:
428   PjRtClient() = default;
PjRtClient(std::unique_ptr<PjRtHostMemoryForDeviceManager> host_memory_for_device_manager)429   explicit PjRtClient(std::unique_ptr<PjRtHostMemoryForDeviceManager>
430                           host_memory_for_device_manager)
431       : host_memory_for_device_manager_(
432             std::move(host_memory_for_device_manager)) {}
433 
434   virtual ~PjRtClient() = default;
435 
436   // Return the process index of this client. Always 0 in single-process
437   // settings.
438   virtual int process_index() const = 0;
439 
440   // Return the number of devices in the entire computation. In multi-headed
441   // client setting, some are addressable by this client, some are not. In a
442   // single-client setting, this is equal to the number of addressable devices.
443   virtual int device_count() const = 0;
444 
445   // Return number of addressable devices. Addressable devices are those that
446   // the client can issue commands to.
447   virtual int addressable_device_count() const = 0;
448 
449   // Return all devices known to the client, including addressable and
450   // non-addressable devices.
451   virtual absl::Span<PjRtDevice* const> devices() const = 0;
452 
453   // Return only addressable devices. The devices are in no particular order.
454   virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
455 
456   // Lookup any PjRtDevice for a given PjRtDevice::id().
457   virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
458 
459   // Return an addressable PjRtDevice for a given
460   // PjRtDevice::local_hardware_id().
461   virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
462       int local_hardware_id) const = 0;
463 
464   // Return an ID that identifies the platform (CPU/GPU/TPU).
465   virtual PjRtPlatformId platform_id() const = 0;
466 
467   // Returns a string that identifies the platform (CPU/GPU/TPU).
468   virtual absl::string_view platform_name() const = 0;
469 
470   // Returns a string containing human-readable, platform-specific version info
471   // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
472   virtual absl::string_view platform_version() const = 0;
473 
474   // Returns an enum that identifies the type of runtime being used under this
475   // client.
476   virtual PjRtRuntimeType runtime_type() const = 0;
477 
478   // Return a device-specific default device assignment, e.g., GPU and TPU may
479   // be different.
480   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
481       int num_replicas, int num_partitions) const = 0;
482 
483   // Returns a device-specific default device assignment for multi-slice system.
484   // If num_replicas_per_slice is not defined (nullopt) then we assume that
485   // all the partitions live entirely on a single slice and that all cross slice
486   // communication happens across replicas assuming then that
487   // num_replicas_per_slice is going to be "num_replicas / num_slices".
488   // TODO(zhangqiaorjc): Convert this to pure virtual and push down.
GetDefaultDeviceAssignment(int num_replicas,std::optional<int> num_replicas_per_slice,int num_partitions,const MultiSliceConfig * multi_slice_config)489   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
490       int num_replicas, std::optional<int> num_replicas_per_slice,
491       int num_partitions, const MultiSliceConfig* multi_slice_config) const {
492     return Unimplemented("Multi slice device assignment is not supported.");
493   }
494 
495   // Returns a backend-specific HLO cost analysis visitor.
496   virtual StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() = 0;
497 
498   // Compile `computation` with given `options`.
499   virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
500       const XlaComputation& computation, CompileOptions options) = 0;
501 
502   // Variant of `Compile` that accepts an MLIR module.
503   virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
504       mlir::ModuleOp module, CompileOptions options) = 0;
505 
506   // Generates a unique fingerprint for `executable`, may be std::nullopt.
507   virtual StatusOr<std::optional<std::string>> ExecutableFingerprint(
508       const PjRtLoadedExecutable& executable) const = 0;
509 
510   // Returns a platform-specific serialization of `executable`. The
511   // serialization is not guaranteed to be stable over time. `executable` must
512   // have been produced by this client.
513   virtual StatusOr<std::string> SerializeExecutable(
514       const PjRtLoadedExecutable& executable) const = 0;
515 
516   // Deserializes a serialized executable as produced by
517   // SerializeExecutable(). `serialized` must have been produced by a client of
518   // the same platform and version as this one.
519   virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
520       absl::string_view serialized, CompileOptions options) = 0;
521 
522   // LoadSerializedExecutable takes the serialized output of PjRtExecutable. The
523   // returned executable is loaded by this client. The same checks are made as
524   // in Load that the serialized executable is compatible with the client.
525   // LoadSerializedExecutable will materialize CompileOptions from within the
526   // serialized executable unlike 'DeserializeExecutable' above that accepts
527   // CompileOptions.
528   virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
LoadSerializedExecutable(absl::string_view serialized)529   LoadSerializedExecutable(absl::string_view serialized) const {
530     return Unimplemented("Loading serialized executable not supported.");
531   }
532 
533   // Loads the executable returns aa PjRtLoadedExecutable runnable by this
534   // client. Returns an error if the PjRtExecutable was created with an
535   // incompatible topology or client.
536   // PjRtExecutable contains a copy of the CompileOptions that was used to
537   // generate the executable. Load will use the CompileOptions from within the
538   // executable.
Load(std::unique_ptr<PjRtExecutable> executable)539   virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Load(
540       std::unique_ptr<PjRtExecutable> executable) {
541     return Unimplemented("Loading executable not supported.");
542   }
543 
544   // Creates a buffer on the device without initializing or copying any data.
545   virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
546       const Shape& shape, PjRtDevice* device) = 0;
547 
548   // A client may want to create a buffer, and hand the buffer to other PjRt
549   // methods, before the data to store in the buffer is available to the client.
550   // This is supported using CreateBuffersForAsyncTransfer, which returns an
551   // AsyncBufferTransferManager helper object.
552   //
553   // The PjRtBuffers can be retrieved from the AsyncBufferTransferManager and
554   // safely passed immediately to downstream PjRt method calls. Subsequently the
555   // client can call methods on the AsyncBufferTransferManager object to copy
556   // data into the buffers, and once the data copies are complete, the buffers'
557   // definition events will automatically become ready, unblocking downstream
558   // consumers of the buffers.
559   //
560   // A single call to CreateBuffersForAsyncTransfer creates a "batch" of buffers
561   // that share a single definition event, which may amortize some performance
562   // overheads, but means that none of the buffers are available to downstream
563   // consumers until all the transfers have completed. Multiple calls to
564   // CreateBuffersForAsyncTransfer should be made if it is desirable for buffers
565   // to become available as soon as transfers into them complete.
566 
567   // Helper class to all clients to asynchronously transfer data into buffers
568   // that are created uninitialized, see comments immediately above.
569   class AsyncBufferTransferManager {
570    public:
571     virtual ~AsyncBufferTransferManager() = default;
572 
573     // Returns the number of buffers managed by this object.
574     virtual size_t buffer_count() const = 0;
575 
576     // Returns the destination device of the transfers.
577     virtual PjRtDevice* device() const = 0;
578 
579     // Returns buffer_index, which can be passed to downstream consumers
580     // immediately and will become available once transfers complete. May not
581     // be called more than once for a given buffer_index.
582     //
583     // RetrieveBuffer can be called at any convenient time; transfer methods
584     // can safely be called for a buffer index after RetrieveBuffer has been
585     // called.
586     virtual std::unique_ptr<PjRtBuffer> RetrieveBuffer(int buffer_index) = 0;
587 
588     // Transfers 'literal' into buffer_index. No transfer calls into
589     // buffer_index can be made after this call. on_done is called when the
590     // transfer is complete but before the buffers are made available to
591     // their consumers. 'literal' must remain in scope until on_done is
592     // called.
593     virtual Status TransferLiteralToBuffer(int buffer_index,
594                                            const LiteralSlice& literal,
595                                            std::function<void()> on_done) = 0;
596 
597     // Returns the on-device size in bytes of buffer buffer_index.
598     virtual size_t buffer_size(int buffer_index) const = 0;
599 
600     // Transfers 'data' into buffer_index. 'data' must be already laid out in
601     // the correct on-device format, for example returned by a call to
602     // buffer->CopyRawToHost. No transfer calls into buffer_index can be made
603     // after this call. on_done is called when the transfer is complete but
604     // before the buffers are made available to their consumers. 'data' must
605     // remain in scope until on_done is called.
606     virtual Status TransferRawDataToBuffer(int buffer_index,
607                                            absl::string_view data,
608                                            std::function<void()> on_done) = 0;
609 
610     // Transfers 'data' into a sub-buffer of buffer_index starting at offset, of
611     // length transfer_size. 'data' must be already laid out in the correct
612     // on-device format, for example returned by a call to
613     // buffer->CopyRawToHost. If is_last_transfer is false then the buffer
614     // remains unavailable to consumers after the transfer completes. If
615     // is_last_transfer is true then the buffer becomes available to consumers
616     // after the transfer completes, and no transfer calls into buffer_index can
617     // be made after this call. on_done is called when the transfer is complete
618     // but before the buffers are made available to their consumers. 'data' must
619     // remain in scope until on_done is called.
620     virtual Status TransferRawDataToSubBuffer(
621         int buffer_index, const void* data, int64_t offset,
622         int64_t transfer_size, bool is_last_transfer,
623         std::function<void()> on_done) = 0;
624 
625     // Indicates that a client error occurred and the transfers will never
626     // complete. Puts all buffers in an error state. For the stream executor
627     // client, since error states are not well supported, this triggers a fatal
628     // error.
629     //
630     // SetTransferError may be called at most once, and may not be called unless
631     // at least one buffer has not yet had its final transfer initiated.
632     virtual void SetTransferError(Status error) = 0;
633 
634     // Adds the specified key/value metadata for the transfer operation.
635     // This is typically used for debugging purposes, such as adding a handle
636     // that can be used to identify transfer operations.
637     using TransferMetadata = absl::flat_hash_map<std::string, std::string>;
638     virtual void AddTransferMetadata(const TransferMetadata& metadata) = 0;
639   };
640 
641   // Returns a manager for async transfers into a set of buffers with on-host
642   // shapes 'shapes'.
643   virtual StatusOr<std::unique_ptr<AsyncBufferTransferManager>>
644   CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
645                                 PjRtDevice* device) = 0;
646 
// The contract a caller of BufferFromHostBuffer expects from the runtime.
// The enumerators are totally ordered from most to least restrictive.
enum class HostBufferSemantics {
  // `data` is only guaranteed to be immutable and alive while
  // BufferFromHostBuffer is running, so the runtime must not keep any
  // reference to it once that call returns. `on_done_with_host_buffer` is
  // invoked before BufferFromHostBuffer returns.
  kImmutableOnlyDuringCall,

  // The runtime may keep reading `data` after BufferFromHostBuffer returns,
  // for as long as the transfer to the device is in progress; when the
  // transfer finishes it invokes `on_done_with_host_buffer`. Until then the
  // caller must neither mutate nor free `data`. Waiting on the host (directly
  // or indirectly) for the buffer's definition event to complete is also a
  // correct way to know that `data` may be released.
  kImmutableUntilTransferCompletes,

  // The PjRtBuffer may alias `data` internally, and the runtime may read
  // `data` for as long as the buffer is alive. For that whole period the
  // caller keeps `data` alive and unmodified. When the PjRtBuffer is freed,
  // the runtime invokes `on_done_with_host_buffer` to tell the caller that
  // `data` may be released. On non-CPU platforms this behaves exactly like
  // kImmutableUntilTransferCompletes.
  kZeroCopy,
};
674 
675   // on_done_with_host_buffer is optional and may be null.
676   // on_done_with_host_buffer will be called iff an OK status is returned.
677   //
678   // `data` points to the backing array of the host buffer. Caution:
679   // `byte_strides` are allowed to be negative, in which case `data` may need
680   // to point to the interior of the buffer, not necessarily its start.
681   //
682   // If byte_strides is omitted, the array is assumed to have a dense layout
683   // with dimensions in major-to-minor order.
684   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
685       const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
686       std::optional<absl::Span<int64_t const>> byte_strides,
687       HostBufferSemantics host_buffer_semantics,
688       std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;
689 
690   // Note that literal must remain in scope until the transfer has completed, so
691   // the caller should, for example, wait for GetReadyFuture().Await()
692   // completes on the return value before letting literal go out of scope.
693   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
694       const LiteralSlice& literal, PjRtDevice* device) = 0;
695 
696   // Creates a PjRtBuffer that is a non-owned view of an on-device
697   // buffer (typically allocated by another library).
698   // on_delete_callback is called when the PjRtBuffer is done with the on-device
699   // buffer. The buffer may be mutated, for example, if the buffer is donated
700   // to an Execute operation.
701   // TODO(phawkins): Currently this API assumes the buffer is ready to use
702   // immediately on the device. Extend it to support, for example, waiting for a
703   // CUDA stream/event.
704   virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
705       void* device_ptr, const Shape& shape, PjRtDevice* device,
706       std::function<void()> on_delete_callback) = 0;
707 
708   // Returns platform-dependent address for the given buffer that is often but
709   // not guaranteed to be the physical/device address.
710   virtual StatusOr<std::uintptr_t> UnsafeBufferPointer(PjRtBuffer* buffer);
711 
712   // Returns a vector of PjRtBuffers that can be used to receive
713   // cross host transfers using `client` on `device'. Asynchronously calls
714   // `notifier` once receive descriptors are ready to be communicated to the
715   // sender. `shapes` must be the exact shapes, with identical layouts,
716   // corresponding to the buffers that will be sent. When resources for the
717   // transfer are available, notifier will be called with a vector of
718   // PjRtCrossHostRecvDescriptors structs, one for each shape in `shapes`. Each
719   // struct contains an opaque string that should be transmitted to the sending
720   // host and used in a call to CopyToRemoteDevice. None of the recv buffers
721   // will become ready until *all* of the sends have completed.
722   //
723   // If MakeCrossHostReceiveBuffers returns an error, then `notifier` will not
724   // be called. Otherwise `notifier` will be called exactly once. In the case
725   // where `notifier` is called with an error status, then the PjRtBuffers
726   // returned by MakeCrossHostReceiveBuffers will never yield data.
727   //
728   // See note on semantics of cross-device copies in the class definition
729   // comment for PjRtClient.
730   virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
731   MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
732                               PjRtDevice* device,
733                               PjRtCrossHostRecvNotifier notifier) = 0;
734 
735   // Asynchronously makes a vector of PjRtBuffers that can be used to receive
736   // cross host transfers, as in MakeCrossHostReceiveBuffers above, however
737   // each buffer expects to be "gathered" using multiple sends, one for each of
738   // a set of sub-slices of the destination buffer.
739   //
740   // For each value in shapes there is a corresponding FullGatherDetails struct
741   // that describes the sub-slices.
742   struct GatherDetails {
743     // The dimensions of the corresponding buffer that the gather slices
744     // into. These dimensions must be the major dimensions in the on-device
745     // layout of the buffer, and must all be untiled. The scatter acts as if
746     // the buffer were transposed/reshaped so that all of these dimensions were
747     // combined into a single dimension whose size is the product of the
748     // dimensions, and the slice indices correspond to indices in that single
749     // combined dimension.
750     //
751     // For example, if the shape is [3, 4, 128, 128] with [3, 4] as the major
752     // dimensions in the layout, and dimensions = {0, 1}, then the buffer is
753     // treated as if it were shape [12, 128, 128] and the indices in
754     // slice_boundaries range in [0, 12].
755     absl::InlinedVector<int, 3> dimensions;
756     // The cumulative indices in dimension of the slices. For example, if
757     // shape.dimensions(dimension)==10, setting slice_boundaries to {2, 5, 10}
758     // would correspond to 3 slices of sizes {2, 3, 5} respectively. If the last
759     // entry in slice_boundaries is less than the size of the combined gather
760     // dimension, the trailing data in the buffer is undefined after the receive
761     // completes.
762     std::vector<int64_t> slice_boundaries;
763   };
764   virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
765   MakeCrossHostReceiveBuffersForGather(
766       absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
767       PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) = 0;
768 
769   // Create ChannelHandles for XLA send/recv.
770   virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
771   virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
772   virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
773 
774   // TODO(zhangqiaorjc): Experimental API to be removed.
775   // Defragment device memory.
776   virtual Status Defragment() = 0;
777 
778   // Return the PjRtHostMemoryForDeviceManager for this client. It can be
779   // nullptr if the implementation does not provide one.
GetPjRtHostMemoryForDeviceManager()780   PjRtHostMemoryForDeviceManager* GetPjRtHostMemoryForDeviceManager() const {
781     return host_memory_for_device_manager_.get();
782   }
783 
784  private:
785   std::unique_ptr<PjRtHostMemoryForDeviceManager>
786       host_memory_for_device_manager_;
787 };
788 
789 // Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
790 // can be either valid or invalid. An invalid buffer is one that has never been
791 // initialized, or a buffer that has been deleted (e.g., by calling Delete, or
792 // by donating it to a computation that aliases an input parameter to an
793 // output). We allow PjRtBuffer objects to outlive the underlying device
794 // buffers so we can decouple buffer lifetimes from the corresponding Python
795 // references if needed. Thread-safe.
796 class PjRtBuffer {
797  public:
798   virtual ~PjRtBuffer() = default;
799 
800   virtual const Shape& on_device_shape() const = 0;
801 
802   // Same as on_device_shape when the shape is static. When the shape is
803   // dynamic, it gathers the metadata from the device and returns a static shape
804   // representing the logical shape of the data. This approach is identical to
805   // how tensorflow and xrt setup the output buffer in the graph.
806   //
807   // Since this method actually acquires locks and communicate with the device,
808   // it does not have the const qualifier, similar to what ToLiteral does.
809   virtual StatusOr<Shape> logical_on_device_shape() = 0;
810   virtual PjRtDevice* device() const = 0;
811   virtual PjRtClient* client() const = 0;
812 
813   // ExternalReference is a potentially long-lived reference held while a buffer
814   // is being shared by an external framework, e.g., NumPy. A client acquires an
815   // external reference by calling PjRtBuffer::AcquireExternalReference() and
816   // releases it by deleting the ExternalReference. The external framework
817   // should not modify the underlying buffer unless it is confident via its own
818   // synchronization that modifications do not race with reads from the
819   // PjRtBuffer.
820   class ExternalReference {
821    public:
822     virtual ~ExternalReference() = 0;
823     // Return opaque device memory pointer to root buffer.
OpaqueDeviceMemoryDataPointer()824     void* OpaqueDeviceMemoryDataPointer() const { return data_ptr_; }
825 
826    protected:
827     void* data_ptr_;
828   };
829   virtual StatusOr<std::unique_ptr<ExternalReference>>
830   AcquireExternalReference() = 0;
831 
832   // Asynchronously copies the buffer's value into `literal`.
833   //
834   // Return value is a future the caller can use to discover when the copy has
835   // completed. The transfer respects the layout of `literal`; to specify a
836   // particular layout, set the layout before calling `ToLiteral`.
837   virtual PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) = 0;
838 
839   // Copies the buffer's value into `literal`. Calls `on_ready` when the value
840   // (or an error) is ready. The transfer respects the layout of `literal`; to
841   // specify a particular layout, set the layout before calling `ToLiteral`.
842   ABSL_DEPRECATED("Use ToLiteral(...).OnReady() instead")
ToLiteral(MutableLiteralBase * literal,std::function<void (Status)> on_ready)843   void ToLiteral(MutableLiteralBase* literal,
844                  std::function<void(Status)> on_ready) {
845     ToLiteral(literal).OnReady(std::move(on_ready));
846   }
847 
848   // Synchronous overload of ToLiteral, as a convenience.
ToLiteralSync(MutableLiteralBase * literal)849   Status ToLiteralSync(MutableLiteralBase* literal) {
850     absl::Notification done;
851     Status status;
852     ToLiteral(literal, [&](Status s) {
853       status = std::move(s);
854       done.Notify();
855     });
856     done.WaitForNotification();
857     return status;
858   }
859 
860   // Convenience synchronous overload that allocates a literal with a default
861   // layout.
ToLiteralSync()862   StatusOr<std::shared_ptr<Literal>> ToLiteralSync() {
863     auto literal = std::make_shared<Literal>(
864         ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
865     TF_RETURN_IF_ERROR(ToLiteralSync(literal.get()));
866     return literal;
867   }
868 
869   // Returns the number of bytes of the buffer storage on the device.
870   virtual StatusOr<size_t> GetOnDeviceSizeInBytes() const = 0;
871 
872   // Transfers a sub-range of the on-device representation of the buffer.
873   // offset+transfer_size must be less than GetOnDeviceSizeInBytes. The
874   // returned future transitions to ready on error, or after the transfer has
875   // completed.
876   virtual PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset,
877                                            int64_t transfer_size) = 0;
878 
879   // Drops the buffer's reference to its associated device memory, leaving the
880   // buffer in an invalid state. The memory will be freed lazily when all async
881   // operations using the buffer have completed, according to the allocation
882   // semantics of the underlying platform. Delete may briefly block if another
883   // thread is in the process of enqueuing an operation on this buffer, but it
884   // will never block for a stream operation to complete. If an external
885   // framework holds a reference to the TrackedDeviceBuffer via
886   // GetBufferWithExternalReference, the memory will not be freed until the
887   // external framework drops the reference.
888   virtual void Delete() = 0;
889 
890   // Similar to Delete, drops the buffer's reference to its associated device
891   // memory, leaving the buffer in an invalid state, but transfers the device
892   // memory ownership out via an ExternalReference rather than
893   // freeing the device memory, so that another framework can take ownership of
894   // it. A return value of nullptr indicates that PjRtBuffer has been
895   // deleted. The buffer returned from Release may be safely dropped at any time
896   // even if it still has pending async operations. The client should call
897   // GetReadyFuture().Await before calling ReleaseDeviceMemoryOwnership with
898   // wait_for_operations_to_complete=false, to ensure that the host has
899   // synchronized past any outstanding write operations to the buffer. If
900   // wait_for_operations_to_complete=true the host will block until any
901   // potentially outstanding asynchronous operations have completed before
902   // returning, in which case it is safe to read or mutate the returned buffer.
903   // If the buffer was shared via an external reference it is the client's
904   // responsibility that accesses via that reference do not interfere with
905   // accesses via the buffer returned from ReleaseDeviceMemoryOwnership.
906   virtual StatusOr<std::unique_ptr<ExternalReference>>
907   ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) = 0;
908 
909   // True if and only if Delete or Release has previously been called.
910   virtual bool IsDeleted() = 0;
911 
912   // Copies the buffer to device `dst_device`, performing a d2d transfer when
913   // `dst_device` is sharing the same Client, and performing a d2h and h2d copy
914   // if `dst_device` lives on a different Client.
915   // Returns an error if the buffer is already on dst_device.
916   //
917   // See note on semantics of cross-device copies in the class definition
918   // comment for PjRtClient.
919   virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
920       PjRtDevice* dst_device) = 0;
921 
922   // Copies the buffer to the remote device encoded in serialized_descriptor.
923   // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
924   // remote host's destination device. MakeCrossHostReceiveBuffers takes an
925   // array of shapes to construct the destination buffers, and a callback
926   // supplies an array containing both the destination buffers, and a serialized
927   // descriptor for each buffer. For each destination buffer there should be a
928   // matching call to src->CopyToRemoteDevice on a remote host for a src buffer
929   // of the corresponding shape. serialized_descriptor is the string returned by
930   // the callback along with the corresponding destination buffer.
931   //
932   // When the send either completes or fails, on_done will be called. If
933   // status is Ok then it is guaranteed that sends_were_enqueued==true.
934   // Otherwise, if sends_were_enqueued==false then the sender should contact
935   // the receiver out of band to request cancellation of the transfer. If
936   // !status.ok() and sends_were_enqueued==true then it is not possible to
937   // determine whether the transfer succeeded and the system is in an
938   // undefined state. This undefined state almost certainly indicates an
939   // unrecoverable hardware error.
940   //
941   // See note on semantics of cross-device copies in the class definition
942   // comment for PjRtClient.
943   using RemoteSendCallback =
944       std::function<void(Status status, bool sends_were_enqueued)>;
945   virtual void CopyToRemoteDevice(absl::string_view serialized_descriptor,
946                                   RemoteSendCallback on_done) = 0;
947   struct ScatterDetails {
948     // The dimensions of the corresponding buffer that the scatter slices
949     // across. These dimensions must be the major dimensions in the on-device
950     // layout of the buffer, and must all be untiled. The scatter acts as if
951     // the buffer were transposed/reshaped so that all of these dimensions were
952     // combined into a single dimension whose size is the product of the
953     // dimensions, and the slice indices correspond to indices in that single
954     // combined dimension.
955     //
956     // For example, if the shape is [3, 4, 128, 128] with [3, 4] as the major
957     // dimensions in the layout, and dimensions = {0, 1}, then the buffer is
958     // treated as if it were shape [12, 128, 128] and the indices in slices
959     // range in [0, 12].
960     absl::InlinedVector<int, 3> dimensions;
961     // The start and end indices of the slices.
962     std::vector<std::pair<int64_t, int64_t>> slices;
963   };
964   virtual void CopyToRemoteDeviceScattered(
965       absl::Span<const std::pair<std::string, RemoteSendCallback>>
966           serialized_descriptors_and_callbacks,
967       const ScatterDetails& scatter_details) = 0;
968 
969   // Returns a future that can be used to discover when the data in the
970   // PjRtBuffer has been computed, or an error has occurred.
971   //
972   // TODO(b/241967811): change these weird semantics
973   // If the buffer has been deleted or donated the returned future will
974   // immediately hold an error, however if GetReadyFuture() is called before
975   // the buffer has been deleted or donated then the returned future will stay
976   // valid (will not transition to error as a consequence of buffer deletion)
977   // even if the buffer is subsequently donated or deleted.
978   virtual PjRtFuture<Status> GetReadyFuture() = 0;
979 
980   // Blocks the host until the buffer's value has been computed and is ready for
981   // immediate use on the device. Useful in particular for timing benchmarks.
982   ABSL_DEPRECATED("Use GetReadyFuture().Await() instead")
BlockHostUntilReady()983   Status BlockHostUntilReady() {
984     auto s = GetReadyFuture().Await();
985     // Fix up error string because some clients rely on it.
986     if (!s.ok() && s.error_message() ==
987                        "GetReadyFuture() called on deleted or donated buffer") {
988       return InvalidArgument(
989           "BlockHostUntilReady() called on deleted or donated buffer");
990     }
991     return s;
992   }
993 
994   // Calls callback when the buffer is ready.
995   //
996   //   buf->OnReady(callback);
997   //
998   // is semantically almost identical to:
999   //
1000   //   ForkThread([]() { callback(buf->Await()); });
1001   //
1002   // the only difference being that the callback may happen immediately on the
1003   // calling thread. (The implementation may also be more efficient.)
1004   //
1005   // The interface makes no assumptions about what thread calls callback, so the
1006   // caller must ensure that callback returns quickly and hands off long-running
1007   // work or any blocking operation to a caller-managed threadpool.
1008   ABSL_DEPRECATED("Use GetReadyFuture().OnReady() instead")
OnReady(std::function<void (Status)> callback)1009   void OnReady(std::function<void(Status)> callback) {
1010     return GetReadyFuture().OnReady(std::move(callback));
1011   }
1012 
1013   // Whether this buffer is on CPU and thus allows for certain optimizations.
1014   virtual bool IsOnCpu() const = 0;
1015 };
1016 
// Polymorphic base for an opaque, implementation-defined context that can be
// handed to an execution through ExecuteOptions::context. It has no state of
// its own; derived classes carry the additional arguments.
class ExecuteContext {
 public:
  virtual ~ExecuteContext() = default;
};
1021 
1022 struct PjRtTransferMetadata {
1023   Shape device_shape;
1024 };
1025 
1026 struct SendCallback {
1027   int64_t channel_id;
1028   // The callback for retrieving the send value. It will be invoked once for
1029   // each invocation of the corresponding Send op in the HLO program (So it can
1030   // be invoked multiple times if it is in a loop). Currently there is no
1031   // guarantee that the callback here will be invoked in the same order as their
1032   // corresponding HLO Send ops. The callback can also return errors to indicate
1033   // the execution should fail.
1034   //
1035   // IMPORTANT: the implementation might NOT signal the error to the execution,
1036   // and the execution will run to completion with UNDEFINED DATA returned by
1037   // the callback. If there is any potential control flow that depends on the
1038   // value of the returned data, an error return is unsafe.
1039   //
1040   // TODO(chky): Currently the callback invocation order may not be consistent
1041   // with the HLO send op invocation order, due to limitations in some PjRt
1042   // implementation. Consider making it strictly the same order as HLO program.
1043   std::function<Status(const PjRtTransferMetadata& metadata, PjRtChunk chunk,
1044                        size_t total_size_in_bytes, bool done)>
1045       callback;
1046 };
1047 
1048 struct RecvCallback {
1049   int64_t channel_id;
1050   // The callback for feeding the recv value. It will be invoked once for each
1051   // invocation of the corresponding Recv op in the HLO program (So it can be
1052   // invoked multiple times if it is in a loop). Currently there is no
1053   // guarantee that the callback here will be invoked in the same order as their
1054   // corresponding HLO Recv ops.
1055   std::function<void(const PjRtTransferMetadata& metadata,
1056                      CopyToDeviceStream& stream)>
1057       callback;
1058 };
1059 
1060 struct ExecuteOptions {
1061   // If true, the client must pass a single PjRtBuffer which contains all of
1062   // the arguments as a single XLA tuple, otherwise each argument must be
1063   // passed in its own PjRtBuffer. May only be true if the executable was
1064   // compiled with parameter_is_tupled_arguments==true.
1065   bool arguments_are_tupled = false;
1066   // If true, the computation must return a tuple, which will be destructured
1067   // into its elements.
1068   bool untuple_result = false;
1069   // If non-zero, identifies this execution as part of a potentially
1070   // multi-device launch. This can be used to detect scheduling errors, e.g. if
1071   // multi-host programs are launched in different orders on different hosts,
1072   // the launch IDs may be used by the runtime to detect the mismatch.
1073   int32_t launch_id = 0;
1074   // If non-null, an opaque context passed to an execution that may be used to
1075   // supply additional arguments to a derived class of PjRtExecutable.
1076   const ExecuteContext* context = nullptr;
1077   // If true, check that the PjRtBuffer argument shapes match the compiled
1078   // shapes. Otherwise, any shape with the right size on device may be passed.
1079   bool strict_shape_checking = true;
1080 
1081   // Set multi_slice_config when the computation spans multiple slices. The
1082   // config should match what was used during compilation to generate this
1083   // executable.
1084   const MultiSliceConfig* multi_slice_config = nullptr;
1085 
1086   // The send/recv callbacks for PjRt execution. The first level span is for
1087   // multi-device parallel execution, the second level vector contains the
1088   // callbacks for all send/recv ops in the executable. These callbacks can be
1089   // stateful and the user code is responsible for managing the states here.
1090   // These callbacks must outlive the execution.
1091   absl::Span<const std::vector<SendCallback>> send_callbacks;
1092   absl::Span<const std::vector<RecvCallback>> recv_callbacks;
1093 
1094   // The `execution_mode` decides whether the execution will be invoked in the
1095   // caller thread or launched to a separate thread. By default, the
1096   // implementation may choose either strategy or use a heuristic to decide.
1097   // Currently it is only applied to CPU implementations
1098   enum class ExecutionMode { kDefault = 0, kSynchronous, kAsynchronous };
1099   ExecutionMode execution_mode = ExecutionMode::kDefault;
1100 };
1101 
1102 // Represents a compiled computation that can be executed given handles to
1103 // device-allocated literals. If any input/output alias has been specified in
1104 // the computation, the parameter containing the input buffer will be donated
1105 // when passed to the execution.
1106 class PjRtLoadedExecutable : public PjRtExecutable {
1107  public:
1108   virtual ~PjRtLoadedExecutable() = default;
1109 
1110   virtual PjRtClient* client() const = 0;
1111 
1112   virtual const DeviceAssignment& device_assignment() const = 0;
1113 
1114   // The replica and partition indices of device_assignment to be run by this
1115   // client. On single-host platforms without partitioning, this is all replicas
1116   // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
1117   // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
1118   // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
1119   struct LogicalDeviceIds {
1120     int replica;
1121     int partition;
1122   };
1123   virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
1124       const = 0;
1125 
1126   // An addressable_device is one which the client can issue commands to.
1127   // addressable_devices()[i] is the Device to which
1128   // addressable_device_logical_ids()[i] is assigned.
1129   virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
1130 
  // Donation semantics:
  //
  // The Execute*() methods below donate an input buffer to the execution if
  // the executable specifies that buffer for donation. Donation is usually
  // implemented as a transaction: it is acquired at the beginning and
  // committed once the device execution has been successfully launched.
  // Concurrent donations may either block or return failures.
  //
  // TODO(chky): It is generally desirable that concurrent donations do not
  // block, since blocking easily leads to deadlock. Consider always returning
  // failure on concurrent donations.
1142 
1143   // Executes on devices addressable by the client. Requires executable has a
1144   // device_assignment and all devices in the device_assignment are addressable
1145   // by the client.
1146   //
1147   // `argument_handles` is `[num_devices, num_args]`.
1148   //
1149   // If returned_futures.has_value():
1150   //   if Execute does not return an error status:
1151   //     *returned_futures will be resized to be the same length as the return
1152   //     vector, and each future will become ready once the corresponding device
1153   //     execute has completed.
1154   //   else:
1155   //     *returned_futures is undefined.
1156   //
1157   // The caller is *NOT* required to ensure that PjRtLoadedExecutable stays
1158   // alive until futures are ready.
1159   virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
1160   Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1161           const ExecuteOptions& options,
1162           std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) = 0;
1163   // Convenience wrapper for Execute that never returns futures.
Execute(absl::Span<const std::vector<PjRtBuffer * >> argument_handles,const ExecuteOptions & options)1164   StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
1165       absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1166       const ExecuteOptions& options) {
1167     std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
1168     return Execute(std::move(argument_handles), options, returned_futures);
1169   }
1170 
1171   // Execute the assigned replica/partition on a given `device`. Requires
1172   // executable has a device_assignment, `device` is present in the
1173   // device_assignment and addressable by the client.
1174   //
1175   // If fill_future is true:
1176   //   if ExecuteSharded does not return an error status:
1177   //     returned_future will be filled with a future that will become ready
1178   //     once the execution has completed.
1179   //    else:
1180   //     returned_future will not be modified.
1181   virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
1182       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1183       const ExecuteOptions& options,
1184       std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) = 0;
1185   // Convenience wrapper for ExecuteSharded that always returns a future.
ExecuteSharded(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future)1186   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
1187       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1188       const ExecuteOptions& options,
1189       std::optional<PjRtFuture<Status>>& returned_future) {
1190     return ExecuteSharded(std::move(argument_handles), device, options,
1191                           returned_future, /*fill_future=*/true);
1192   }
1193   // Convenience wrapper for ExecuteSharded that never returns a future.
ExecuteSharded(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options)1194   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
1195       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1196       const ExecuteOptions& options) {
1197     std::optional<PjRtFuture<Status>> returned_future;
1198     return ExecuteSharded(std::move(argument_handles), device, options,
1199                           returned_future, /*fill_future=*/false);
1200   }
1201 
1202   // Execute on a given `device`. Requires `device` to be addressable by client.
1203   // Requires executable has exactly 1 replica and 1 partition and no
1204   // device_assignment (thus portable).
1205   //
1206   // If fill_future is true:
1207   //   if ExecutePortable does not return an error status:
1208   //     returned_future will be filled with a future that will become ready
1209   //     once the execution has completed.
1210   //    else:
1211   //     returned_future will not be modified.
1212   virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
1213       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1214       const ExecuteOptions& options,
1215       std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) = 0;
1216   // Convenience wrapper for ExecutePortable that always returns a future.
ExecutePortable(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future)1217   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
1218       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1219       const ExecuteOptions& options,
1220       std::optional<PjRtFuture<Status>>& returned_future) {
1221     return ExecutePortable(std::move(argument_handles), device, options,
1222                            returned_future, /*fill_future=*/true);
1223   }
1224   // Convenience wrapper for ExecutePortable that never returns a future.
ExecutePortable(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options)1225   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
1226       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1227       const ExecuteOptions& options) {
1228     std::optional<PjRtFuture<Status>> returned_future;
1229     return ExecutePortable(std::move(argument_handles), device, options,
1230                            returned_future, /*fill_future=*/false);
1231   }
1232 
  // Asynchronously free resources after the last execution completes.
  // Pure virtual: every concrete executable must provide an implementation.
  virtual void Delete() = 0;
1235 
  // True if on-device resources associated with the executable are freed.
  // Pure virtual: every concrete executable must provide an implementation.
  virtual bool IsDeleted() = 0;
1238 
  // True if the `returned_futures` output parameter is supported in the
  // Execute*() methods. The base-class implementation returns false; the
  // method is virtual so implementations that support it can return true.
  //
  // TODO(b/240696624): Although the PjRt interface requires `returned_futures`
  // to be resized correctly if it is not nullopt, some implementations do not
  // do this, so callers have to check whether returned_futures is empty.
  // Remove this method once the implementations are fixed.
  virtual bool IsReturnedFutureSupported() const { return false; }
1247 
1248  protected:
  // Value returned internally from routines that enqueue an execution,
  // combining the result buffers with a future that becomes ready when the
  // execution completes.
  struct Result {
    // Future that becomes ready when the execution completes. Being a
    // std::optional it may be empty. NOTE(review): when it is left empty is
    // decided by the producing routine; confirm there.
    std::optional<PjRtFuture<Status>> future;
    // Result buffers of the execution, each uniquely owned via unique_ptr.
    std::vector<std::unique_ptr<PjRtBuffer>> buffers;
  };
1256 };
1257 
1258 }  // namespace xla
1259 
1260 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
1261