1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
18
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_executable.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
47
48 // API notes:
49 // PjRt stands for "Pretty much Just another RunTime".
50
51 namespace xla {
52
// Numeric identifier of a PjRt platform (see CpuId()/GpuId()/TpuId()).
using PjRtPlatformId = uint64_t;

// Canonical platform names. Each accessor returns a pointer to a
// NUL-terminated character array with static storage duration, so the pointer
// remains valid for the lifetime of the program.
inline const char* CpuName() {
  static constexpr char kName[] = "cpu";
  return kName;
}

inline const char* GpuName() {
  static constexpr char kName[] = "gpu";
  return kName;
}

inline const char* TpuName() {
  static constexpr char kName[] = "tpu";
  return kName;
}
CpuId()67 inline PjRtPlatformId CpuId() {
68 static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(CpuName());
69 return kCpuId;
70 }
GpuId()71 inline PjRtPlatformId GpuId() {
72 static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(GpuName());
73 return kGpuId;
74 }
TpuId()75 inline PjRtPlatformId TpuId() {
76 static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(TpuName());
77 return kTpuId;
78 }
79
80 enum PjRtRuntimeType { kStreamExecutor, kTfrt };
PjRtRuntimeTypeString(PjRtRuntimeType type)81 inline constexpr absl::string_view PjRtRuntimeTypeString(PjRtRuntimeType type) {
82 switch (type) {
83 case kStreamExecutor:
84 return "stream_executor";
85 case kTfrt:
86 return "tfrt";
87 }
88 }
89
class PjRtClient;

// Value of a vendor-specific device attribute, as returned by
// PjRtDevice::Attributes(): a string, a scalar integer, or a list of integers.
using PjRtDeviceAttribute =
    std::variant<std::string, std::int64_t, std::vector<std::int64_t>>;
94
95 class PjRtDevice {
96 public:
~PjRtDevice()97 virtual ~PjRtDevice() {}
98
99 // Return the client that owns this device.
100 virtual PjRtClient* client() const = 0;
101
102 // Whether client can issue command to this device.
103 virtual bool IsAddressable() const = 0;
104
105 // The ID of this device. IDs are unique among devices of this type
106 // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
107 // hosts' devices. This is the ID that should be used in a DeviceAssignment.
108 virtual int id() const = 0;
109
110 // The index of the process that this device belongs to, i.e. is addressable
111 // from. This is not always identical to PjRtClient::process_index() in a
112 // multi-process setting, where each client can see devices from all
113 // processes, but only a subset of them are addressable and have the same
114 // process_index as the client.
115 virtual int process_index() const = 0;
116
117 // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
118 // which GPU when interacting with non-JAX code. In general, not guaranteed to
119 // be dense, and -1 if undefined.
120 virtual int local_hardware_id() const = 0;
121
122 // A vendor-dependent string that uniquely identifies the kind of device,
123 // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
124 // compatible compilation.
125 virtual absl::string_view device_kind() const = 0;
126
127 // Debug string suitable for logging when errors occur. Should be verbose
128 // enough to describe the current device unambiguously.
129 virtual absl::string_view DebugString() const = 0;
130
131 // Debug string suitable for reading by end users, should be reasonably terse,
132 // for example: "CpuDevice(id=0)".
133 virtual absl::string_view ToString() const = 0;
134
135 // Returns a scoped event that the caller uses to tell the PjRtClient that
136 // there is asynchronous work happening that depends on activity on the
137 // PjRtDevice. See comment on class definition in pjrt_future.h.
138 //
139 // Only some PjRtDevice implementations support ScopedAsyncTrackingEvent, and
140 // those that do not will return nullptr.
141 virtual std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
142 absl::string_view description) const = 0;
143
144 // Transfer the given literal to the infeed queue.
145 virtual Status TransferToInfeed(const LiteralSlice& literal) = 0;
146
147 // Transfer and return a value of the given shape from the outfeed queue.
148 virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0;
149
150 // Returns vendor specific attributes about the device. For example the model
151 // number of a GPU, or the mesh coordinates of a TPU device. The returned
152 // reference will remain valid for the lifetime of the PjRtDevice.
153 virtual const absl::flat_hash_map<std::string, PjRtDeviceAttribute>&
154 Attributes() const = 0;
155 };
156
157 // Forward declaration.
158 class PjRtBuffer;
159
160 // Helper struct for cross host transfers, returned by the callback from a call
161 // to PjRtBuffer::MakeCrossHostReceiveBuffers or
162 // PjRtBuffer::MakeCrossHostReceiveBuffersForGather.
163 struct PjRtCrossHostRecvDescriptors {
164 // There is one serialized_descriptor per sub-buffer being gathered (i.e. a
165 // single descriptor if the buffer is returned from a call to
166 // MakeCrossHostReceiveBuffers). The descriptor should be transmitted to the
167 // sender(s) and passed to a call to src_buffer->CopyToRemoteDevice.
168 absl::InlinedVector<std::string, 1> serialized_descriptors;
169 };
170 // Function that the client should call at the receiver if it needs to cancel a
171 // cross-host send, for example because the buffer that the remote host wanted
172 // to send is not available. The serialized descriptor should match one of the
173 // descriptors returned in a PjRtCrossHostRecvDescriptors. on_canceled will be
174 // called once cancellation is complete and indicates whether cancellation was
175 // successful or not.
176 //
177 // For each serialized_descriptor provided in a PjRtCrossHostRecvDescriptors,
178 // *either* the sending host must successfully complete a CopyToRemoteDevice
179 // for that descriptor, *or* the receiving host must cancel. If there is a
180 // duplicate (e.g., both send and cancel) then the system will be left in an
181 // undefined state. If there is no send or cancellation then the system will
182 // hang indefinitely.
183 using PjRtCrossHostSendCancelNotifier =
184 std::function<void(absl::string_view serialized_descriptor, Status reason,
185 std::function<void(Status)> on_canceled)>;
186 // State asynchronously returned by MakeCrossHostReceiveBuffers. "descriptors"
187 // will match the returned PjRtBuffer objects 1:1. Specifically, each PjRtBuffer
188 // returned by MakeCrossHostReceiveBuffers will have one
189 // PjRtCrossHostRecvDescriptors object containing it descriptor(s).
190 struct PjRtCrossHostRecvState {
191 std::vector<PjRtCrossHostRecvDescriptors> descriptors;
192 PjRtCrossHostSendCancelNotifier cancel_notifier;
193 };
194 using PjRtCrossHostRecvNotifier =
195 std::function<void(StatusOr<PjRtCrossHostRecvState>)>;
196
197 // Provides configuration for implementations that support compile and execute
198 // spanning multiple slices. A slice is a set of devices connected by dedicated
199 // high speed interconnect. Connectivity between slices is typically over data
200 // center networks. Concrete implementations of MultiSliceConfig contain
201 // environment specific information to enable communication between devices on
202 // different slices. Passed as options during compile and execute.
203 // Implementations that do not support this are allowed to pass nullptr.
204 class MultiSliceConfig {
205 public:
206 virtual ~MultiSliceConfig();
207
208 // Returns the total number of slices.
209 virtual int32_t NumSlices() const = 0;
210
211 // Returns the SliceID at this host - an integer in [0, NumSlices)
212 virtual int32_t SliceId() const = 0;
213
214 // Returns the number of devices on each slice indexed by SliceId.
215 virtual absl::flat_hash_map<int32_t, int32_t> NumDevicesPerSlice() const = 0;
216 };
217
218 struct CompileOptions {
219 // The layouts of the arguments that the computation should expect.
220 std::optional<std::vector<Shape>> argument_layouts;
221
222 // If true, the supplied computation expects its arguments to be wrapped in a
223 // tuple and passed as a single parameter.
224 bool parameter_is_tupled_arguments = false;
225
226 // XLA's compilation time options.
227 ExecutableBuildOptions executable_build_options;
228
229 // If true, the executable can be run on any device. May only be true if
230 // !executable_build_options.has_device_assignment(), so only applies to
231 // single-device executables. Beware: on GPUs, sometimes an executable
232 // compiled for one device doesn't run on another.
233 bool compile_portable_executable = false;
234
235 // XLA compilation profile version.
236 int64_t profile_version = 0;
237
238 // Set multi_slice_config to trigger compilation for DCN connected multi
239 // slice operation.
240 const MultiSliceConfig* multi_slice_config = nullptr;
241
242 // Serialize the CompileOptions into a CompileOptionsProto.
243 StatusOr<CompileOptionsProto> ToProto() const;
244 };
245
246 StatusOr<CompileOptions> CompileOptionsFromProto(
247 const CompileOptionsProto& input);
248
// A sized chunk of host data. The bytes may be laid out in either the host or
// the device layout, and may cover only part of a larger buffer. The chunk owns
// its bytes and releases them through a caller-supplied deleter, which lets
// PjRt implementations customize how the memory is allocated and deallocated.
//
// PjRtChunk is move-only. A moved-from chunk is empty: data() == nullptr and
// size() == 0.
class PjRtChunk {
 public:
  // Allocates an uninitialized chunk of `size` bytes with std::malloc; the
  // bytes are released with std::free. The allocation result is not checked:
  // if std::malloc fails, data() is null while size() still reports `size`.
  static PjRtChunk AllocateDefault(size_t size) {
    return PjRtChunk(std::malloc(size), size,
                     [](void* ptr) { std::free(ptr); });
  }

  PjRtChunk() = default;

  // Takes ownership of the `size` bytes at `data`. Unless `data` is null,
  // `deleter` is invoked with `data` when the chunk is destroyed or
  // overwritten by a move-assignment.
  PjRtChunk(void* data, size_t size, std::function<void(void*)> deleter)
      : data_(static_cast<uint8_t*>(data)),
        size_(size),
        deleter_(std::move(deleter)) {}

  ~PjRtChunk() { Release(); }

  PjRtChunk(PjRtChunk&& other) noexcept
      : data_(std::exchange(other.data_, nullptr)),
        size_(std::exchange(other.size_, 0)),
        deleter_(std::move(other.deleter_)) {}

  PjRtChunk& operator=(PjRtChunk&& other) noexcept {
    // Without this guard, a self-move would free our own bytes and leave the
    // chunk empty.
    if (this != &other) {
      Release();
      data_ = std::exchange(other.data_, nullptr);
      size_ = std::exchange(other.size_, 0);
      deleter_ = std::move(other.deleter_);
    }
    return *this;
  }

  PjRtChunk(const PjRtChunk&) = delete;
  PjRtChunk& operator=(const PjRtChunk&) = delete;

  uint8_t* data() { return data_; }
  const uint8_t* data() const { return data_; }
  int64_t size() const { return static_cast<int64_t>(size_); }

 private:
  // Frees the owned bytes (if any) through `deleter_` and nulls `data_`.
  void Release() {
    if (data_ != nullptr) {
      deleter_(data_);
      data_ = nullptr;
    }
  }

  // The ownership of the bytes pointed to by `data_` is controlled by
  // `deleter_`.
  uint8_t* data_ = nullptr;
  size_t size_ = 0;
  std::function<void(void*)> deleter_;
};
302
303 // A stream of Chunks from the host to the device. Once the stream enters
304 // Complete state it never changes state again.
305 //
306 // This class is thread-safe.
307 class CopyToDeviceStream {
308 public:
CopyToDeviceStream(int64_t total_bytes,int64_t granule_bytes)309 explicit CopyToDeviceStream(int64_t total_bytes, int64_t granule_bytes)
310 : total_bytes_(total_bytes), granule_bytes_(granule_bytes) {}
311
312 // Emplaces a new Chunk of data to copy to the device. Returns a non-OK status
313 // if the Chunk's size causes the amount of transferred data to exceed
314 // total_bytes() or if the stream is already complete.
315 //
316 // The size of the chunk must be a multiple of granule_bytes().
317 // TODO(jmolloy): Enforce the granule size.
318 Status AddChunk(PjRtChunk chunk);
319
320 // Returns the total amount of data the stream expects to be transferred.
total_bytes()321 int64_t total_bytes() const { return total_bytes_; }
322
323 // Returns the granule size in bytes. The size of the chunk added to this
324 // stream must be a multiple of this number.
granule_size_in_bytes()325 int64_t granule_size_in_bytes() const { return granule_bytes_; }
326
327 // Returns the amount of data the stream currently has either transferred or
328 // has buffered to transfer.
current_bytes()329 int64_t current_bytes() const {
330 absl::MutexLock lock(&mu_);
331 return current_bytes_;
332 }
333
334 // Returns true if the stream is complete; all expected bytes have been
335 // transferred or are buffered to transfer.
IsComplete()336 bool IsComplete() const {
337 absl::MutexLock lock(&mu_);
338 return current_bytes_ == total_bytes_;
339 }
340
341 // Returns true if the stream is empty; no data has been queued.
empty()342 bool empty() const { return current_bytes() == 0; }
343
344 // Consumes the next chunk. If no chunks remain, returns nullopt. Blocks
345 // until a chunk is available.
346 std::optional<PjRtChunk> ConsumeNextChunk();
347
348 // Members are protected to allow subclassing for mocking in tests.
349 protected:
350 int64_t total_bytes_;
351 int64_t granule_bytes_;
352 int64_t current_bytes_ ABSL_GUARDED_BY(mu_) = 0;
353 std::deque<PjRtChunk> buffered_chunks_ ABSL_GUARDED_BY(mu_);
354 mutable absl::Mutex mu_;
355 };
356
357 class PjRtHostMemoryForDeviceManager {
358 public:
359 virtual ~PjRtHostMemoryForDeviceManager();
360
361 // Transforms the host memory representations of a shape with the host layout
362 // to the host memory representation of the same shape with the device layout.
363 // `src_shape` and `dst_shape` may only differ in their layouts.
364 virtual StatusOr<PjRtChunk> ToDeviceLayout(const void* src_data,
365 size_t src_size,
366 const Shape& host_shape,
367 const Shape& device_shape) = 0;
368
369 // Transforms the host memory representations of a shape with the device
370 // layout to the host memory representation of the same shape with the host
371 // layout. `src_shape` and `dst_shape` may only differ in their layouts.
372 virtual Status ToHostLayout(const void* src_data, size_t src_size,
373 const Shape& src_shape, void* dst_data,
374 size_t dst_size, const Shape& dst_shape) = 0;
375 };
376
377 class PjRtLoadedExecutable;
378
// Encapsulates the state of a Python session with XLA.
380 //
381 // It is the responsibility of the client of this API to keep the PjRtClient
382 // alive as long as any of the other runtime objects are alive.
383 //
384 // A note on the semantics of cross-device copies.
385 //
386 // There are two mechanisms to transfer a buffer from one device to another.
387 // When both devices are on the same host (more specifically, the user program
388 // ends up with pointers to both the source and destination buffers in the same
389 // address space), the caller can use:
390 // dst_buffer = src_buffer->CopyToDevice(dst_device)
391 //
392 // When the source and destination are on different hosts, but the transfer is
393 // made via native device networking (as opposed to the user program fetching
394 // the buffer and sending it using its own networking code), the caller can
395 // use:
396 // DstHost: dst_client->MakeCrossHostReceiveBuffers(...)
397 // DstHost: [...]
398 // DstHost: gets callback containing PjRtCrossHostRecvDescriptors
399 // DstHost: sends cross-host recv serialized descriptors to SrcHost
400 // SrcHost: src_buffer->CopyToRemoteDevice(serialized_descriptors)
401 //
402 // Note that in the cross-host case, the dst_client may call
403 // MakeCrossHostReceiveBuffers before the action that produces src_buffer has
404 // been enqueued at SrcHost.
405 //
406 // On some platforms, device-to-device transfers consume scarce hardware
407 // resources. If dst_client->MakeCrossHostReceiveBuffers immediately claimed
408 // those resources, then there would be a risk of system-wide deadlock, if the
409 // resources claimed by the recv prevented other transfers that are necessary
410 // to generate src_buffer from acquiring enough resources to proceed.
411 //
412 // In order to allow clients to avoid deadlocks such as those in the preceding
413 // paragraph, PjRtClient guarantees progress but not fairness with respect to
414 // the order that cross-device transfers are enqueued on a given host, as
415 // follows:
416 //
417 // The progress guarantee is that a cross-device transfer T on host A will not
418 // claim scarce hardware resources until it is guaranteed that all transfers
419 // enqueued on A before T have already either completed, or been assigned enough
420 // resources to ensure that they can eventually complete.
421 //
422 // The lack of a fairness guarantee means that, if cross-device transfer T1 is
423 // enqueued before transfer T2 at A, then T2 may complete before T1. T1 may be
424 // delayed for an unbounded time waiting for T2 if T2 is large, even though T1
425 // will eventually be able to make progress.
426 class PjRtClient {
427 public:
428 PjRtClient() = default;
PjRtClient(std::unique_ptr<PjRtHostMemoryForDeviceManager> host_memory_for_device_manager)429 explicit PjRtClient(std::unique_ptr<PjRtHostMemoryForDeviceManager>
430 host_memory_for_device_manager)
431 : host_memory_for_device_manager_(
432 std::move(host_memory_for_device_manager)) {}
433
434 virtual ~PjRtClient() = default;
435
436 // Return the process index of this client. Always 0 in single-process
437 // settings.
438 virtual int process_index() const = 0;
439
440 // Return the number of devices in the entire computation. In multi-headed
441 // client setting, some are addressable by this client, some are not. In a
442 // single-client setting, this is equal to the number of addressable devices.
443 virtual int device_count() const = 0;
444
445 // Return number of addressable devices. Addressable devices are those that
446 // the client can issue commands to.
447 virtual int addressable_device_count() const = 0;
448
449 // Return all devices known to the client, including addressable and
450 // non-addressable devices.
451 virtual absl::Span<PjRtDevice* const> devices() const = 0;
452
453 // Return only addressable devices. The devices are in no particular order.
454 virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
455
456 // Lookup any PjRtDevice for a given PjRtDevice::id().
457 virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
458
459 // Return an addressable PjRtDevice for a given
460 // PjRtDevice::local_hardware_id().
461 virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
462 int local_hardware_id) const = 0;
463
464 // Return an ID that identifies the platform (CPU/GPU/TPU).
465 virtual PjRtPlatformId platform_id() const = 0;
466
467 // Returns a string that identifies the platform (CPU/GPU/TPU).
468 virtual absl::string_view platform_name() const = 0;
469
470 // Returns a string containing human-readable, platform-specific version info
471 // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
472 virtual absl::string_view platform_version() const = 0;
473
474 // Returns an enum that identifies the type of runtime being used under this
475 // client.
476 virtual PjRtRuntimeType runtime_type() const = 0;
477
478 // Return a device-specific default device assignment, e.g., GPU and TPU may
479 // be different.
480 virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
481 int num_replicas, int num_partitions) const = 0;
482
483 // Returns a device-specific default device assignment for multi-slice system.
484 // If num_replicas_per_slice is not defined (nullopt) then we assume that
485 // all the partitions live entirely on a single slice and that all cross slice
486 // communication happens across replicas assuming then that
487 // num_replicas_per_slice is going to be "num_replicas / num_slices".
488 // TODO(zhangqiaorjc): Convert this to pure virtual and push down.
GetDefaultDeviceAssignment(int num_replicas,std::optional<int> num_replicas_per_slice,int num_partitions,const MultiSliceConfig * multi_slice_config)489 virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
490 int num_replicas, std::optional<int> num_replicas_per_slice,
491 int num_partitions, const MultiSliceConfig* multi_slice_config) const {
492 return Unimplemented("Multi slice device assignment is not supported.");
493 }
494
495 // Returns a backend-specific HLO cost analysis visitor.
496 virtual StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() = 0;
497
498 // Compile `computation` with given `options`.
499 virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
500 const XlaComputation& computation, CompileOptions options) = 0;
501
502 // Variant of `Compile` that accepts an MLIR module.
503 virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
504 mlir::ModuleOp module, CompileOptions options) = 0;
505
506 // Generates a unique fingerprint for `executable`, may be std::nullopt.
507 virtual StatusOr<std::optional<std::string>> ExecutableFingerprint(
508 const PjRtLoadedExecutable& executable) const = 0;
509
510 // Returns a platform-specific serialization of `executable`. The
511 // serialization is not guaranteed to be stable over time. `executable` must
512 // have been produced by this client.
513 virtual StatusOr<std::string> SerializeExecutable(
514 const PjRtLoadedExecutable& executable) const = 0;
515
516 // Deserializes a serialized executable as produced by
517 // SerializeExecutable(). `serialized` must have been produced by a client of
518 // the same platform and version as this one.
519 virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
520 absl::string_view serialized, CompileOptions options) = 0;
521
522 // LoadSerializedExecutable takes the serialized output of PjRtExecutable. The
523 // returned executable is loaded by this client. The same checks are made as
524 // in Load that the serialized executable is compatible with the client.
525 // LoadSerializedExecutable will materialize CompileOptions from within the
526 // serialized executable unlike 'DeserializeExecutable' above that accepts
527 // CompileOptions.
528 virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
LoadSerializedExecutable(absl::string_view serialized)529 LoadSerializedExecutable(absl::string_view serialized) const {
530 return Unimplemented("Loading serialized executable not supported.");
531 }
532
533 // Loads the executable returns aa PjRtLoadedExecutable runnable by this
534 // client. Returns an error if the PjRtExecutable was created with an
535 // incompatible topology or client.
536 // PjRtExecutable contains a copy of the CompileOptions that was used to
537 // generate the executable. Load will use the CompileOptions from within the
538 // executable.
Load(std::unique_ptr<PjRtExecutable> executable)539 virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Load(
540 std::unique_ptr<PjRtExecutable> executable) {
541 return Unimplemented("Loading executable not supported.");
542 }
543
544 // Creates a buffer on the device without initializing or copying any data.
545 virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
546 const Shape& shape, PjRtDevice* device) = 0;
547
548 // A client may want to create a buffer, and hand the buffer to other PjRt
549 // methods, before the data to store in the buffer is available to the client.
550 // This is supported using CreateBuffersForAsyncTransfer, which returns an
551 // AsyncBufferTransferManager helper object.
552 //
553 // The PjRtBuffers can be retrieved from the AsyncBufferTransferManager and
554 // safely passed immediately to downstream PjRt method calls. Subsequently the
555 // client can call methods on the AsyncBufferTransferManager object to copy
556 // data into the buffers, and once the data copies are complete, the buffers'
557 // definition events will automatically become ready, unblocking downstream
558 // consumers of the buffers.
559 //
560 // A single call to CreateBuffersForAsyncTransfer creates a "batch" of buffers
561 // that share a single definition event, which may amortize some performance
562 // overheads, but means that none of the buffers are available to downstream
563 // consumers until all the transfers have completed. Multiple calls to
564 // CreateBuffersForAsyncTransfer should be made if it is desirable for buffers
565 // to become available as soon as transfers into them complete.
566
567 // Helper class to all clients to asynchronously transfer data into buffers
568 // that are created uninitialized, see comments immediately above.
569 class AsyncBufferTransferManager {
570 public:
571 virtual ~AsyncBufferTransferManager() = default;
572
573 // Returns the number of buffers managed by this object.
574 virtual size_t buffer_count() const = 0;
575
576 // Returns the destination device of the transfers.
577 virtual PjRtDevice* device() const = 0;
578
579 // Returns buffer_index, which can be passed to downstream consumers
580 // immediately and will become available once transfers complete. May not
581 // be called more than once for a given buffer_index.
582 //
583 // RetrieveBuffer can be called at any convenient time; transfer methods
584 // can safely be called for a buffer index after RetrieveBuffer has been
585 // called.
586 virtual std::unique_ptr<PjRtBuffer> RetrieveBuffer(int buffer_index) = 0;
587
588 // Transfers 'literal' into buffer_index. No transfer calls into
589 // buffer_index can be made after this call. on_done is called when the
590 // transfer is complete but before the buffers are made available to
591 // their consumers. 'literal' must remain in scope until on_done is
592 // called.
593 virtual Status TransferLiteralToBuffer(int buffer_index,
594 const LiteralSlice& literal,
595 std::function<void()> on_done) = 0;
596
597 // Returns the on-device size in bytes of buffer buffer_index.
598 virtual size_t buffer_size(int buffer_index) const = 0;
599
600 // Transfers 'data' into buffer_index. 'data' must be already laid out in
601 // the correct on-device format, for example returned by a call to
602 // buffer->CopyRawToHost. No transfer calls into buffer_index can be made
603 // after this call. on_done is called when the transfer is complete but
604 // before the buffers are made available to their consumers. 'data' must
605 // remain in scope until on_done is called.
606 virtual Status TransferRawDataToBuffer(int buffer_index,
607 absl::string_view data,
608 std::function<void()> on_done) = 0;
609
610 // Transfers 'data' into a sub-buffer of buffer_index starting at offset, of
611 // length transfer_size. 'data' must be already laid out in the correct
612 // on-device format, for example returned by a call to
613 // buffer->CopyRawToHost. If is_last_transfer is false then the buffer
614 // remains unavailable to consumers after the transfer completes. If
615 // is_last_transfer is true then the buffer becomes available to consumers
616 // after the transfer completes, and no transfer calls into buffer_index can
617 // be made after this call. on_done is called when the transfer is complete
618 // but before the buffers are made available to their consumers. 'data' must
619 // remain in scope until on_done is called.
620 virtual Status TransferRawDataToSubBuffer(
621 int buffer_index, const void* data, int64_t offset,
622 int64_t transfer_size, bool is_last_transfer,
623 std::function<void()> on_done) = 0;
624
625 // Indicates that a client error occurred and the transfers will never
626 // complete. Puts all buffers in an error state. For the stream executor
627 // client, since error states are not well supported, this triggers a fatal
628 // error.
629 //
630 // SetTransferError may be called at most once, and may not be called unless
631 // at least one buffer has not yet had its final transfer initiated.
632 virtual void SetTransferError(Status error) = 0;
633
634 // Adds the specified key/value metadata for the transfer operation.
635 // This is typically used for debugging purposes, such as adding a handle
636 // that can be used to identify transfer operations.
637 using TransferMetadata = absl::flat_hash_map<std::string, std::string>;
638 virtual void AddTransferMetadata(const TransferMetadata& metadata) = 0;
639 };
640
641 // Returns a manager for async transfers into a set of buffers with on-host
642 // shapes 'shapes'.
643 virtual StatusOr<std::unique_ptr<AsyncBufferTransferManager>>
644 CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
645 PjRtDevice* device) = 0;
646
647 // Describes the semantics the caller to BufferFromHostBuffer expects from the
648 // runtime, in a total order from most restrictive to least restrictive.
649 enum class HostBufferSemantics {
650 // The runtime may not hold references to `data` after the call to
651 // `BufferFromHostBuffer` completes. The caller promises that `data` is
652 // immutable and will not be freed only for the duration of the
653 // BufferFromHostBuffer call. `on_done_with_host_buffer` will be called
654 // before `BufferFromHostBuffer` returns.
655 kImmutableOnlyDuringCall,
656
657 // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
658 // returns while the runtime completes a transfer to the device. The caller
659 // promises not to mutate or free `data` until the transfer completes, at
660 // which point the runtime will call `on_done_with_host_buffer`. It is also
661 // correct to wait on the host (directly or indirectly) for the buffer's
662 // definition event to complete.
663 kImmutableUntilTransferCompletes,
664
665 // The PjRtBuffer may alias `data` internally and the runtime may use the
666 // `data` contents as long as the buffer is alive. The caller promises to
667 // keep `data` alive and not to mutate its contents as long as the buffer is
668 // alive; to notify the caller that the buffer may be freed, the runtime
669 // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On
670 // non-CPU platforms this acts identically to
671 // kImmutableUntilTransferCompletes.
672 kZeroCopy,
673 };
674
// Returns a new PjRtBuffer on `device` whose contents come from the host
// array at `data`, with element type `type` and dimensions `dims`. How long
// `data` must remain valid is governed by `host_buffer_semantics` (see
// HostBufferSemantics above).
//
// on_done_with_host_buffer is optional and may be null.
// on_done_with_host_buffer will be called iff an OK status is returned; the
// point at which it is called depends on `host_buffer_semantics`.
//
// `data` points to the backing array of the host buffer. Caution:
// `byte_strides` are allowed to be negative, in which case `data` may need
// to point to the interior of the buffer, not necessarily its start.
//
// If byte_strides is omitted, the array is assumed to have a dense layout
// with dimensions in major-to-minor order.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
    const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
    std::optional<absl::Span<int64_t const>> byte_strides,
    HostBufferSemantics host_buffer_semantics,
    std::function<void()> on_done_with_host_buffer, PjRtDevice* device) = 0;
689
// Returns a new PjRtBuffer on `device` whose contents come from `literal`.
//
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for GetReadyFuture().Await() on the
// returned buffer to complete before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
    const LiteralSlice& literal, PjRtDevice* device) = 0;
695
// Creates a PjRtBuffer that is a non-owned view of an on-device
// buffer (typically allocated by another library) located at `device_ptr`.
// Because the view does not own the memory, the caller remains responsible
// for it; on_delete_callback is called when the PjRtBuffer is done with the
// on-device buffer. The buffer may be mutated, for example, if the buffer is
// donated to an Execute operation.
// TODO(phawkins): Currently this API assumes the buffer is ready to use
// immediately on the device. Extend it to support, for example, waiting for a
// CUDA stream/event.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
    void* device_ptr, const Shape& shape, PjRtDevice* device,
    std::function<void()> on_delete_callback) = 0;
707
// Returns platform-dependent address for the given buffer that is often but
// not guaranteed to be the physical/device address.
// Unlike its neighbours this virtual is not pure (`= 0`), so a base-class
// definition exists outside this header and overriding it is optional.
virtual StatusOr<std::uintptr_t> UnsafeBufferPointer(PjRtBuffer* buffer);
711
// Returns a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device`. Asynchronously calls
// `notifier` once receive descriptors are ready to be communicated to the
// sender. `shapes` must be the exact shapes, with identical layouts,
// corresponding to the buffers that will be sent. When resources for the
// transfer are available, notifier will be called with a vector of
// PjRtCrossHostRecvDescriptors structs, one for each shape in `shapes`. Each
// struct contains an opaque string that should be transmitted to the sending
// host and used in a call to CopyToRemoteDevice. None of the recv buffers
// will become ready until *all* of the sends have completed.
//
// If MakeCrossHostReceiveBuffers returns an error, then `notifier` will not
// be called. Otherwise `notifier` will be called exactly once. In the case
// where `notifier` is called with an error status, then the PjRtBuffers
// returned by MakeCrossHostReceiveBuffers will never yield data.
//
// See note on semantics of cross-device copies in the class definition
// comment for PjRtClient.
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
                            PjRtDevice* device,
                            PjRtCrossHostRecvNotifier notifier) = 0;
734
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers, as in MakeCrossHostReceiveBuffers above, however
// each buffer expects to be "gathered" using multiple sends, one for each of
// a set of sub-slices of the destination buffer.
//
// For each value in `shapes` there is a corresponding GatherDetails struct
// (the same index in `gather_details`) that describes the sub-slices.
struct GatherDetails {
  // The dimensions of the corresponding buffer that the gather slices
  // into. These dimensions must be the major dimensions in the on-device
  // layout of the buffer, and must all be untiled. The gather acts as if
  // the buffer were transposed/reshaped so that all of these dimensions were
  // combined into a single dimension whose size is the product of the
  // dimensions, and the slice indices correspond to indices in that single
  // combined dimension.
  //
  // For example, if the shape is [3, 4, 128, 128] with [3, 4] as the major
  // dimensions in the layout, and dimensions = {0, 1}, then the buffer is
  // treated as if it were shape [12, 128, 128] and the indices in
  // slice_boundaries range in [0, 12].
  absl::InlinedVector<int, 3> dimensions;
  // The cumulative indices, within the combined gather dimension, at which
  // each slice ends. For example, if the combined dimension has size 10,
  // setting slice_boundaries to {2, 5, 10} would correspond to 3 slices of
  // sizes {2, 3, 5} respectively. If the last entry in slice_boundaries is
  // less than the size of the combined gather dimension, the trailing data in
  // the buffer is undefined after the receive completes.
  std::vector<int64_t> slice_boundaries;
};
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffersForGather(
    absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
    PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) = 0;
768
// Create ChannelHandles for XLA send/recv.
// CreateChannelHandle is the general factory. Judging by their names, the
// other two produce handles for device-to-host and host-to-device transfers
// respectively. NOTE(review): confirm the exact channel kinds against the
// implementations.
virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
773
// TODO(zhangqiaorjc): Experimental API to be removed.
// Defragment device memory. Returns a non-OK Status on failure; callers
// should not rely on this method remaining in the interface.
virtual Status Defragment() = 0;
777
// Return the PjRtHostMemoryForDeviceManager for this client. It can be
// nullptr if the implementation does not provide one. Ownership stays with
// the client; the returned pointer is non-owning.
PjRtHostMemoryForDeviceManager* GetPjRtHostMemoryForDeviceManager() const {
  return host_memory_for_device_manager_.get();
}

private:
// Owned by the client and exposed (non-owning) through
// GetPjRtHostMemoryForDeviceManager(); may be null.
std::unique_ptr<PjRtHostMemoryForDeviceManager>
    host_memory_for_device_manager_;
787 };
788
789 // Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
790 // can be either valid or invalid. An invalid buffer is one that has never been
791 // initialized, or a buffer that has been deleted (e.g., by calling Delete, or
792 // by donating it to a computation that aliases an input parameter to an
793 // output). We allow PjRtBuffer objects to outlive the underlying device
794 // buffers so we can decouple buffer lifetimes from the corresponding Python
795 // references if needed. Thread-safe.
796 class PjRtBuffer {
797 public:
798 virtual ~PjRtBuffer() = default;
799
800 virtual const Shape& on_device_shape() const = 0;
801
802 // Same as on_device_shape when the shape is static. When the shape is
803 // dynamic, it gathers the metadata from the device and returns a static shape
804 // representing the logical shape of the data. This approach is identical to
805 // how tensorflow and xrt setup the output buffer in the graph.
806 //
807 // Since this method actually acquires locks and communicate with the device,
808 // it does not have the const qualifier, similar to what ToLiteral does.
809 virtual StatusOr<Shape> logical_on_device_shape() = 0;
810 virtual PjRtDevice* device() const = 0;
811 virtual PjRtClient* client() const = 0;
812
813 // ExternalReference is a potentially long-lived reference held while a buffer
814 // is being shared by an external framework, e.g., NumPy. A client acquires an
815 // external reference by calling PjRtBuffer::AcquireExternalReference() and
816 // releases it by deleting the ExternalReference. The external framework
817 // should not modify the underlying buffer unless it is confident via its own
818 // synchronization that modifications do not race with reads from the
819 // PjRtBuffer.
820 class ExternalReference {
821 public:
822 virtual ~ExternalReference() = 0;
823 // Return opaque device memory pointer to root buffer.
OpaqueDeviceMemoryDataPointer()824 void* OpaqueDeviceMemoryDataPointer() const { return data_ptr_; }
825
826 protected:
827 void* data_ptr_;
828 };
829 virtual StatusOr<std::unique_ptr<ExternalReference>>
830 AcquireExternalReference() = 0;
831
832 // Asynchronously copies the buffer's value into `literal`.
833 //
834 // Return value is a future the caller can use to discover when the copy has
835 // completed. The transfer respects the layout of `literal`; to specify a
836 // particular layout, set the layout before calling `ToLiteral`.
837 virtual PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) = 0;
838
839 // Copies the buffer's value into `literal`. Calls `on_ready` when the value
840 // (or an error) is ready. The transfer respects the layout of `literal`; to
841 // specify a particular layout, set the layout before calling `ToLiteral`.
842 ABSL_DEPRECATED("Use ToLiteral(...).OnReady() instead")
ToLiteral(MutableLiteralBase * literal,std::function<void (Status)> on_ready)843 void ToLiteral(MutableLiteralBase* literal,
844 std::function<void(Status)> on_ready) {
845 ToLiteral(literal).OnReady(std::move(on_ready));
846 }
847
848 // Synchronous overload of ToLiteral, as a convenience.
ToLiteralSync(MutableLiteralBase * literal)849 Status ToLiteralSync(MutableLiteralBase* literal) {
850 absl::Notification done;
851 Status status;
852 ToLiteral(literal, [&](Status s) {
853 status = std::move(s);
854 done.Notify();
855 });
856 done.WaitForNotification();
857 return status;
858 }
859
860 // Convenience synchronous overload that allocates a literal with a default
861 // layout.
ToLiteralSync()862 StatusOr<std::shared_ptr<Literal>> ToLiteralSync() {
863 auto literal = std::make_shared<Literal>(
864 ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
865 TF_RETURN_IF_ERROR(ToLiteralSync(literal.get()));
866 return literal;
867 }
868
869 // Returns the number of bytes of the buffer storage on the device.
870 virtual StatusOr<size_t> GetOnDeviceSizeInBytes() const = 0;
871
872 // Transfers a sub-range of the on-device representation of the buffer.
873 // offset+transfer_size must be less than GetOnDeviceSizeInBytes. The
874 // returned future transitions to ready on error, or after the transfer has
875 // completed.
876 virtual PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset,
877 int64_t transfer_size) = 0;
878
879 // Drops the buffer's reference to its associated device memory, leaving the
880 // buffer in an invalid state. The memory will be freed lazily when all async
881 // operations using the buffer have completed, according to the allocation
882 // semantics of the underlying platform. Delete may briefly block if another
883 // thread is in the process of enqueuing an operation on this buffer, but it
884 // will never block for a stream operation to complete. If an external
885 // framework holds a reference to the TrackedDeviceBuffer via
886 // GetBufferWithExternalReference, the memory will not be freed until the
887 // external framework drops the reference.
888 virtual void Delete() = 0;
889
890 // Similar to Delete, drops the buffer's reference to its associated device
891 // memory, leaving the buffer in an invalid state, but transfers the device
892 // memory ownership out via an ExternalReference rather than
893 // freeing the device memory, so that another framework can take ownership of
894 // it. A return value of nullptr indicates that PjRtBuffer has been
895 // deleted. The buffer returned from Release may be safely dropped at any time
896 // even if it still has pending async operations. The client should call
897 // GetReadyFuture().Await before calling ReleaseDeviceMemoryOwnership with
898 // wait_for_operations_to_complete=false, to ensure that the host has
899 // synchronized past any outstanding write operations to the buffer. If
900 // wait_for_operations_to_complete=true the host will block until any
901 // potentially outstanding asynchronous operations have completed before
902 // returning, in which case it is safe to read or mutate the returned buffer.
903 // If the buffer was shared via an external reference it is the client's
904 // responsibility that accesses via that reference do not interfere with
905 // accesses via the buffer returned from ReleaseDeviceMemoryOwnership.
906 virtual StatusOr<std::unique_ptr<ExternalReference>>
907 ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) = 0;
908
909 // True if and only if Delete or Release has previously been called.
910 virtual bool IsDeleted() = 0;
911
912 // Copies the buffer to device `dst_device`, performing a d2d transfer when
913 // `dst_device` is sharing the same Client, and performing a d2h and h2d copy
914 // if `dst_device` lives on a different Client.
915 // Returns an error if the buffer is already on dst_device.
916 //
917 // See note on semantics of cross-device copies in the class definition
918 // comment for PjRtClient.
919 virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
920 PjRtDevice* dst_device) = 0;
921
922 // Copies the buffer to the remote device encoded in serialized_descriptor.
923 // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
924 // remote host's destination device. MakeCrossHostReceiveBuffers takes an
925 // array of shapes to construct the destination buffers, and a callback
926 // supplies an array containing both the destination buffers, and a serialized
927 // descriptor for each buffer. For each destination buffer there should be a
928 // matching call to src->CopyToRemoteDevice on a remote host for a src buffer
929 // of the corresponding shape. serialized_descriptor is the string returned by
930 // the callback along with the corresponding destination buffer.
931 //
932 // When the send either completes or fails, on_done will be called. If
933 // status is Ok then it is guaranteed that sends_were_enqueued==true.
934 // Otherwise, if sends_were_enqueued==false then the sender should contact
935 // the receiver out of band to request cancellation of the transfer. If
936 // !status.ok() and sends_were_enqueued==true then it is not possible to
937 // determine whether the transfer succeeded and the system is in an
938 // undefined state. This undefined state almost certainly indicates an
939 // unrecoverable hardware error.
940 //
941 // See note on semantics of cross-device copies in the class definition
942 // comment for PjRtClient.
943 using RemoteSendCallback =
944 std::function<void(Status status, bool sends_were_enqueued)>;
945 virtual void CopyToRemoteDevice(absl::string_view serialized_descriptor,
946 RemoteSendCallback on_done) = 0;
947 struct ScatterDetails {
948 // The dimensions of the corresponding buffer that the scatter slices
949 // across. These dimensions must be the major dimensions in the on-device
950 // layout of the buffer, and must all be untiled. The scatter acts as if
951 // the buffer were transposed/reshaped so that all of these dimensions were
952 // combined into a single dimension whose size is the product of the
953 // dimensions, and the slice indices correspond to indices in that single
954 // combined dimension.
955 //
956 // For example, if the shape is [3, 4, 128, 128] with [3, 4] as the major
957 // dimensions in the layout, and dimensions = {0, 1}, then the buffer is
958 // treated as if it were shape [12, 128, 128] and the indices in slices
959 // range in [0, 12].
960 absl::InlinedVector<int, 3> dimensions;
961 // The start and end indices of the slices.
962 std::vector<std::pair<int64_t, int64_t>> slices;
963 };
964 virtual void CopyToRemoteDeviceScattered(
965 absl::Span<const std::pair<std::string, RemoteSendCallback>>
966 serialized_descriptors_and_callbacks,
967 const ScatterDetails& scatter_details) = 0;
968
969 // Returns a future that can be used to discover when the data in the
970 // PjRtBuffer has been computed, or an error has occurred.
971 //
972 // TODO(b/241967811): change these weird semantics
973 // If the buffer has been deleted or donated the returned future will
974 // immediately hold an error, however if GetReadyFuture() is called before
975 // the buffer has been deleted or donated then the returned future will stay
976 // valid (will not transition to error as a consequence of buffer deletion)
977 // even if the buffer is subsequently donated or deleted.
978 virtual PjRtFuture<Status> GetReadyFuture() = 0;
979
980 // Blocks the host until the buffer's value has been computed and is ready for
981 // immediate use on the device. Useful in particular for timing benchmarks.
982 ABSL_DEPRECATED("Use GetReadyFuture().Await() instead")
BlockHostUntilReady()983 Status BlockHostUntilReady() {
984 auto s = GetReadyFuture().Await();
985 // Fix up error string because some clients rely on it.
986 if (!s.ok() && s.error_message() ==
987 "GetReadyFuture() called on deleted or donated buffer") {
988 return InvalidArgument(
989 "BlockHostUntilReady() called on deleted or donated buffer");
990 }
991 return s;
992 }
993
994 // Calls callback when the buffer is ready.
995 //
996 // buf->OnReady(callback);
997 //
998 // is semantically almost identical to:
999 //
1000 // ForkThread([]() { callback(buf->Await()); });
1001 //
1002 // the only difference being that the callback may happen immediately on the
1003 // calling thread. (The implementation may also be more efficient.)
1004 //
1005 // The interface makes no assumptions about what thread calls callback, so the
1006 // caller must ensure that callback returns quickly and hands off long-running
1007 // work or any blocking operation to a caller-managed threadpool.
1008 ABSL_DEPRECATED("Use GetReadyFuture().OnReady() instead")
OnReady(std::function<void (Status)> callback)1009 void OnReady(std::function<void(Status)> callback) {
1010 return GetReadyFuture().OnReady(std::move(callback));
1011 }
1012
1013 // Whether this buffer is on CPU and thus allows for certain optimizations.
1014 virtual bool IsOnCpu() const = 0;
1015 };
1016
// Opaque context that may be attached to an execution through
// ExecuteOptions::context to supply additional arguments to a derived class
// of PjRtExecutable. This base declares no members of its own;
// implementations presumably downcast it to their own subclass (TODO
// confirm).
class ExecuteContext {
 public:
  // Virtual so that subclasses can be destroyed through an ExecuteContext*.
  virtual ~ExecuteContext() = default;
};
1021
// Metadata handed to SendCallback and RecvCallback callbacks describing the
// value being transferred.
struct PjRtTransferMetadata {
  // Presumably the on-device shape (including layout) of the transferred
  // value, as the field name suggests. TODO confirm with implementations.
  Shape device_shape;
};
1025
1026 struct SendCallback {
1027 int64_t channel_id;
1028 // The callback for retrieving the send value. It will be invoked once for
1029 // each invocation of the corresponding Send op in the HLO program (So it can
1030 // be invoked multiple times if it is in a loop). Currently there is no
1031 // guarantee that the callback here will be invoked in the same order as their
1032 // corresponding HLO Send ops. The callback can also return errors to indicate
1033 // the execution should fail.
1034 //
1035 // IMPORTANT: the implementation might NOT signal the error to the execution,
1036 // and the execution will run to completion with UNDEFINED DATA returned by
1037 // the callback. If there is any potential control flow that depends on the
1038 // value of the returned data, an error return is unsafe.
1039 //
1040 // TODO(chky): Currently the callback invocation order may not be consistent
1041 // with the HLO send op invocation order, due to limitations in some PjRt
1042 // implementation. Consider making it strictly the same order as HLO program.
1043 std::function<Status(const PjRtTransferMetadata& metadata, PjRtChunk chunk,
1044 size_t total_size_in_bytes, bool done)>
1045 callback;
1046 };
1047
1048 struct RecvCallback {
1049 int64_t channel_id;
1050 // The callback for feeding the recv value. It will be invoked once for each
1051 // invocation of the corresponding Recv op in the HLO program (So it can be
1052 // invoked multiple times if it is in a loop). Currently there is no
1053 // guarantee that the callback here will be invoked in the same order as their
1054 // corresponding HLO Recv ops.
1055 std::function<void(const PjRtTransferMetadata& metadata,
1056 CopyToDeviceStream& stream)>
1057 callback;
1058 };
1059
// Per-call options accepted by the PjRtLoadedExecutable::Execute* methods.
// Every member has a default, so a value-initialized ExecuteOptions is a valid
// set of options.
struct ExecuteOptions {
  // If true, the client must pass a single PjRtBuffer which contains all of
  // the arguments as a single XLA tuple, otherwise each argument must be
  // passed in its own PjRtBuffer. May only be true if the executable was
  // compiled with parameter_is_tupled_arguments==true.
  bool arguments_are_tupled = false;
  // If true, the computation must return a tuple, which will be destructured
  // into its elements.
  bool untuple_result = false;
  // If non-zero, identifies this execution as part of a potentially
  // multi-device launch. This can be used to detect scheduling errors, e.g. if
  // multi-host programs are launched in different orders on different hosts,
  // the launch IDs may be used by the runtime to detect the mismatch.
  int32_t launch_id = 0;
  // If non-null, an opaque context passed to an execution that may be used to
  // supply additional arguments to a derived class of PjRtExecutable.
  // Not owned by ExecuteOptions.
  const ExecuteContext* context = nullptr;
  // If true, check that the PjRtBuffer argument shapes match the compiled
  // shapes. Otherwise, any shape with the right size on device may be passed.
  bool strict_shape_checking = true;

  // Set multi_slice_config when the computation spans multiple slices. The
  // config should match what was used during compilation to generate this
  // executable. Not owned by ExecuteOptions.
  const MultiSliceConfig* multi_slice_config = nullptr;

  // The send/recv callbacks for PjRt execution. The first level span is for
  // multi-device parallel execution, the second level vector contains the
  // callbacks for all send/recv ops in the executable. These callbacks can be
  // stateful and the user code is responsible for managing the states here.
  // These callbacks must outlive the execution. Since these are non-owning
  // spans, the vectors they point to must outlive the execution as well.
  absl::Span<const std::vector<SendCallback>> send_callbacks;
  absl::Span<const std::vector<RecvCallback>> recv_callbacks;

  // The `execution_mode` decides whether the execution will be invoked in the
  // caller thread or launched to a separate thread. By default, the
  // implementation may choose either strategy or use a heuristic to decide.
  // Currently it is only applied to CPU implementations.
  enum class ExecutionMode { kDefault = 0, kSynchronous, kAsynchronous };
  ExecutionMode execution_mode = ExecutionMode::kDefault;
};
1101
1102 // Represents a compiled computation that can be executed given handles to
1103 // device-allocated literals. If any input/output alias has been specified in
1104 // the computation, the parameter containing the input buffer will be donated
1105 // when passed to the execution.
1106 class PjRtLoadedExecutable : public PjRtExecutable {
1107 public:
1108 virtual ~PjRtLoadedExecutable() = default;
1109
1110 virtual PjRtClient* client() const = 0;
1111
1112 virtual const DeviceAssignment& device_assignment() const = 0;
1113
1114 // The replica and partition indices of device_assignment to be run by this
1115 // client. On single-host platforms without partitioning, this is all replicas
1116 // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
1117 // case on multi-host platforms. If there are 4 replicas and 2 partitions on a
1118 // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
1119 struct LogicalDeviceIds {
1120 int replica;
1121 int partition;
1122 };
1123 virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
1124 const = 0;
1125
1126 // An addressable_device is one which the client can issue commands to.
1127 // addressable_devices()[i] is the Device to which
1128 // addressable_device_logical_ids()[i] is assigned.
1129 virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
1130
1131 // Donation Semantics:
1132 //
1133 // The following Execute*() methods will donate the input buffer to the
1134 // execution if it is specified in the executable. Donation is usually
1135 // implemented as a transaction: it is acquired in the begining and committed
1136 // when the device execution is successully launched. Concurrent donations
1137 // might either block or return failures.
1138 //
1139 // TODO(chky): It is generally desired that concurrent donations do not block,
1140 // as it otherwise results in deadlock easily. Consider always returning
1141 // failure on concurrent donations.
1142
1143 // Executes on devices addressable by the client. Requires executable has a
1144 // device_assignment and all devices in the device_assignment are addressable
1145 // by the client.
1146 //
1147 // `argument_handles` is `[num_devices, num_args]`.
1148 //
1149 // If returned_futures.has_value():
1150 // if Execute does not return an error status:
1151 // *returned_futures will be resized to be the same length as the return
1152 // vector, and each future will become ready once the corresponding device
1153 // execute has completed.
1154 // else:
1155 // *returned_futures is undefined.
1156 //
1157 // The caller is *NOT* required to ensure that PjRtLoadedExecutable stays
1158 // alive until futures are ready.
1159 virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
1160 Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1161 const ExecuteOptions& options,
1162 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) = 0;
1163 // Convenience wrapper for Execute that never returns futures.
Execute(absl::Span<const std::vector<PjRtBuffer * >> argument_handles,const ExecuteOptions & options)1164 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
1165 absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
1166 const ExecuteOptions& options) {
1167 std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
1168 return Execute(std::move(argument_handles), options, returned_futures);
1169 }
1170
1171 // Execute the assigned replica/partition on a given `device`. Requires
1172 // executable has a device_assignment, `device` is present in the
1173 // device_assignment and addressable by the client.
1174 //
1175 // If fill_future is true:
1176 // if ExecuteSharded does not return an error status:
1177 // returned_future will be filled with a future that will become ready
1178 // once the execution has completed.
1179 // else:
1180 // returned_future will not be modified.
1181 virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
1182 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1183 const ExecuteOptions& options,
1184 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) = 0;
1185 // Convenience wrapper for ExecuteSharded that always returns a future.
ExecuteSharded(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future)1186 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
1187 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1188 const ExecuteOptions& options,
1189 std::optional<PjRtFuture<Status>>& returned_future) {
1190 return ExecuteSharded(std::move(argument_handles), device, options,
1191 returned_future, /*fill_future=*/true);
1192 }
1193 // Convenience wrapper for ExecuteSharded that never returns a future.
ExecuteSharded(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options)1194 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
1195 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1196 const ExecuteOptions& options) {
1197 std::optional<PjRtFuture<Status>> returned_future;
1198 return ExecuteSharded(std::move(argument_handles), device, options,
1199 returned_future, /*fill_future=*/false);
1200 }
1201
1202 // Execute on a given `device`. Requires `device` to be addressable by client.
1203 // Requires executable has exactly 1 replica and 1 partition and no
1204 // device_assignment (thus portable).
1205 //
1206 // If fill_future is true:
1207 // if ExecutePortable does not return an error status:
1208 // returned_future will be filled with a future that will become ready
1209 // once the execution has completed.
1210 // else:
1211 // returned_future will not be modified.
1212 virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
1213 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1214 const ExecuteOptions& options,
1215 std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) = 0;
1216 // Convenience wrapper for ExecutePortable that always returns a future.
ExecutePortable(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options,std::optional<PjRtFuture<Status>> & returned_future)1217 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
1218 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1219 const ExecuteOptions& options,
1220 std::optional<PjRtFuture<Status>>& returned_future) {
1221 return ExecutePortable(std::move(argument_handles), device, options,
1222 returned_future, /*fill_future=*/true);
1223 }
1224 // Convenience wrapper for ExecutePortable that never returns a future.
ExecutePortable(absl::Span<PjRtBuffer * const> argument_handles,PjRtDevice * device,const ExecuteOptions & options)1225 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
1226 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
1227 const ExecuteOptions& options) {
1228 std::optional<PjRtFuture<Status>> returned_future;
1229 return ExecutePortable(std::move(argument_handles), device, options,
1230 returned_future, /*fill_future=*/false);
1231 }
1232
1233 // Asynchronously free resources after the last execution completes.
1234 virtual void Delete() = 0;
1235
1236 // True if on-device resources associated with the executable are freed.
1237 virtual bool IsDeleted() = 0;
1238
1239 // True if the `returned_futures` output parameter is supported in the
1240 // Execute*() methods.
1241 //
1242 // TODO(b/240696624): Although the PjRt interface require `returned_futures`
1243 // to be resized correctly if it is not nullopt, some implementation does not
1244 // implement this. So we have to check whether returned_futures is empty.
1245 // Remove this method once the implementation is fixed.
IsReturnedFutureSupported()1246 virtual bool IsReturnedFutureSupported() const { return false; }
1247
1248 protected:
1249 // Value returned internally from routines that enqueue an execution,
1250 // combining the result buffers with a future that becomes ready when the
1251 // execution completes.
1252 struct Result {
1253 std::optional<PjRtFuture<Status>> future;
1254 std::vector<std::unique_ptr<PjRtBuffer>> buffers;
1255 };
1256 };
1257
1258 } // namespace xla
1259
1260 #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
1261