// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/base/thread_annotations.h"
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/compiler/xla/client/executable_build_options.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/layout.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
31 #include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
32 #include "tensorflow/compiler/xla/pjrt/semaphore.h"
33 #include "tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h"
34 #include "tensorflow/compiler/xla/pjrt/transpose.h"
35 #include "tensorflow/compiler/xla/pjrt/worker_thread.h"
36 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
37 #include "tensorflow/compiler/xla/service/computation_placer.h"
38 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
39 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
40 #include "tensorflow/compiler/xla/service/executable.h"
41 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
42 #include "tensorflow/compiler/xla/service/hlo_module_util.h"
43 #include "tensorflow/compiler/xla/statusor.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/profiler/lib/traceme.h"
46 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
47 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
48 
49 namespace xla {
50 
51 class TfrtCpuDevice final : public PjRtDevice {
52  public:
53   TfrtCpuDevice(int id, bool asynchronous);
54 
SetClient(PjRtClient * client)55   void SetClient(PjRtClient* client) {
56     CHECK(client_ == nullptr);
57     client_ = client;
58   }
59 
client()60   PjRtClient* client() const override { return client_; }
61 
IsAddressable()62   bool IsAddressable() const override {
63     return process_index() == client()->process_index();
64   }
65 
id()66   int id() const override { return id_; }
67 
process_index()68   int process_index() const override { return 0; }
69 
70   // Used as `device_ordinal`.
local_hardware_id()71   int local_hardware_id() const override { return id_; }
72 
73   absl::string_view device_kind() const override;
74 
75   absl::string_view DebugString() const override;
76 
77   absl::string_view ToString() const override;
78 
79   Status TransferToInfeed(const LiteralSlice& literal) override;
80 
81   Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;
82 
83   // Returns a semaphore for admission control on inflight computations.
max_inflight_computations_semaphore()84   Semaphore& max_inflight_computations_semaphore() {
85     return max_inflight_computations_semaphore_;
86   }
87 
CreateAsyncTrackingEvent(absl::string_view description)88   std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
89       absl::string_view description) const override {
90     return nullptr;
91   }
92 
Attributes()93   const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
94       const override {
95     return attributes_;
96   }
97 
98  private:
99   int id_;
100   PjRtClient* client_ = nullptr;
101   std::string debug_string_;
102   std::string to_string_;
103 
104   // TODO(zhangqiaorjc): Optimize semaphore related overhead.
105   // Semaphore used to limit how many programs can be enqueued by the host
106   // ahead of the device.
107   Semaphore max_inflight_computations_semaphore_;
108   absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_ = {};
109 };
110 
111 class TfrtCpuExecutable;
112 
113 class TfrtCpuClient final : public PjRtClient {
114  public:
115   TfrtCpuClient(int process_index,
116                 std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
117                 std::unique_ptr<tfrt::HostContext> host_ctx);
118   ~TfrtCpuClient();
119 
process_index()120   int process_index() const override { return process_index_; }
121 
device_count()122   int device_count() const override { return devices_.size(); }
123 
addressable_device_count()124   int addressable_device_count() const override {
125     return addressable_devices_.size();
126   }
127 
devices()128   absl::Span<PjRtDevice* const> devices() const override { return devices_; }
129 
addressable_devices()130   absl::Span<PjRtDevice* const> addressable_devices() const override {
131     return addressable_devices_;
132   }
133 
134   StatusOr<PjRtDevice*> LookupDevice(int device_id) const override;
135 
136   StatusOr<PjRtDevice*> LookupAddressableDevice(
137       int local_hardware_id) const override;
138 
platform_id()139   PjRtPlatformId platform_id() const override {
140     return tensorflow::Fingerprint64(CpuName());
141   }
142 
platform_name()143   absl::string_view platform_name() const override { return CpuName(); }
144 
platform_version()145   absl::string_view platform_version() const override { return "<unknown>"; }
146 
runtime_type()147   PjRtRuntimeType runtime_type() const override { return kTfrt; }
148 
149   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
150       int num_replicas, int num_partitions) const override;
151 
152   StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;
153 
154   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
155       const XlaComputation& computation, CompileOptions options) override;
156   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
157       mlir::ModuleOp module, CompileOptions options) override;
158 
159   StatusOr<std::optional<std::string>> ExecutableFingerprint(
160       const PjRtLoadedExecutable& executable) const override;
161 
SerializeExecutable(const PjRtLoadedExecutable & executable)162   StatusOr<std::string> SerializeExecutable(
163       const PjRtLoadedExecutable& executable) const override {
164     return Unimplemented("SerializeExecutable not implemented on %s",
165                          platform_name());
166   }
167 
DeserializeExecutable(absl::string_view serialized,CompileOptions options)168   StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
169       absl::string_view serialized, CompileOptions options) override {
170     return Unimplemented("DeserializeExecutable not implemented on %s",
171                          platform_name());
172   }
173 
174   StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
175       const Shape& shape, PjRtDevice* device) override;
176 
177   StatusOr<std::unique_ptr<PjRtClient::AsyncBufferTransferManager>>
CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,PjRtDevice * device)178   CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
179                                 PjRtDevice* device) override {
180     return Unimplemented("Async transfer to buffers not implemented");
181   };
182 
183   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
184       const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
185       std::optional<absl::Span<int64_t const>> byte_strides,
186       HostBufferSemantics host_buffer_semantics,
187       std::function<void()> on_done_with_host_buffer,
188       PjRtDevice* device) override;
189 
190   StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
191       const LiteralSlice& literal, PjRtDevice* device) override;
192 
193   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)194   MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
195                               PjRtDevice* device,
196                               PjRtCrossHostRecvNotifier notifier) override {
197     return Unimplemented("MakeCrossHostReceiveBuffers not implemented.");
198   }
199 
200   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeCrossHostReceiveBuffersForGather(absl::Span<const Shape> shapes,std::vector<GatherDetails> gather_details,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)201   MakeCrossHostReceiveBuffersForGather(
202       absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
203       PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override {
204     return Unimplemented(
205         "MakeCrossHostReceiveBuffersForGather not implemented.");
206   }
207 
208   StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
209       void* device_ptr, const Shape& shape, PjRtDevice* device,
210       std::function<void()> on_delete_callback) override;
211 
CreateChannelHandle()212   StatusOr<ChannelHandle> CreateChannelHandle() override {
213     return Unimplemented("CreateChannelHandle not implemented.");
214   }
CreateDeviceToHostChannelHandle()215   StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
216     return Unimplemented("CreateDeviceToHostChannelHandle not implemented.");
217   }
CreateHostToDeviceChannelHandle()218   StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
219     return Unimplemented("CreateHostToDeviceChannelHandle not implemented.");
220   }
221 
Defragment()222   Status Defragment() override {
223     return Unimplemented("Defragment not implemented.");
224   }
225 
GetHostContext()226   tfrt::HostContext* GetHostContext() const { return host_ctx_.get(); }
227 
eigen_intraop_device()228   Eigen::ThreadPoolDevice* eigen_intraop_device() const {
229     return eigen_intraop_device_.get();
230   }
231 
GetLastCollectiveLaunchEvent()232   tfrt::AsyncValueRef<CpuEvent> GetLastCollectiveLaunchEvent() {
233     absl::MutexLock lock(&mu_);
234     return last_collective_launch_event_.CopyRef();
235   }
236 
SetLastCollectiveLaunchEvent(tfrt::AsyncValueRef<CpuEvent> event)237   void SetLastCollectiveLaunchEvent(tfrt::AsyncValueRef<CpuEvent> event) {
238     absl::MutexLock lock(&mu_);
239     last_collective_launch_event_ = std::move(event);
240   }
241 
242  private:
243   int process_index_;
244   // Includes all devices, including non-addressable devices.
245   std::vector<std::unique_ptr<TfrtCpuDevice>> owned_devices_;
246   // Pointers to `owned_devices_`.
247   std::vector<PjRtDevice*> devices_;
248   // Maps Device::id() to the corresponding Device. Includes all devices.
249   absl::flat_hash_map<int, TfrtCpuDevice*> id_to_device_;
250   // Addressable devices indexed by core_id.
251   std::vector<PjRtDevice*> addressable_devices_;
252   std::unique_ptr<tfrt::HostContext> host_ctx_;
253   std::unique_ptr<ComputationPlacer> computation_placer_;
254 
255   // TODO(zhangqiaorjc): Use tfrt::compat::EigenHostContextThreadPool.
256   std::unique_ptr<tensorflow::thread::ThreadPool> eigen_intraop_pool_;
257   std::unique_ptr<Eigen::ThreadPoolDevice> eigen_intraop_device_;
258 
259   // Launching collectives are prone to deadlock when we use fixed-sized
260   // threadpools since ExecuteHelper will block until all replicas reach the
261   // barrier. We ensure that
262   // 1. Threadpool size is at least as large as device_count so one collective
263   //    launch over all devices can succeed.
264   // 2. Gang-schedule each collective by conservatively ensuring a total order
265   //    of collectives and launching only one collective at a time to avoid
266   //    having no active threads to make progress
267   // TODO(zhangqiaorjc): Explore alternatives that allow multiple concurrent
268   // collectives.
269   mutable absl::Mutex mu_;
270   tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event_
271       ABSL_GUARDED_BY(mu_);
272 
273   // A cache for transpose plans. We use transposes to convert
274   // (possibly strided) buffers provided to BufferFromHostBuffer into dense
275   // major-to-minor layout.
276   absl::Mutex transpose_mu_;
277   TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_);
278 };
279 
280 class TfrtCpuBuffer final : public PjRtBuffer {
281  public:
282   TfrtCpuBuffer(
283       Shape on_device_shape,
284       std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer,
285       TfrtCpuClient* client, TfrtCpuDevice* device);
286   ~TfrtCpuBuffer() override;
287 
288   TfrtCpuBuffer(const TfrtCpuBuffer&) = delete;
289   TfrtCpuBuffer(TfrtCpuBuffer&&) = delete;
290   TfrtCpuBuffer& operator=(const TfrtCpuBuffer&) = delete;
291   TfrtCpuBuffer& operator=(TfrtCpuBuffer&&) = delete;
292 
on_device_shape()293   const Shape& on_device_shape() const override { return on_device_shape_; }
device()294   TfrtCpuDevice* device() const override { return device_; }
client()295   TfrtCpuClient* client() const override { return client_; }
296 
297   StatusOr<Shape> logical_on_device_shape() override;
298 
299   StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
300       override;
301 
302   StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
303       bool wait_for_operations_to_complete) override;
304 
305   using PjRtBuffer::ToLiteralSync;
306   PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override;
307 
308   StatusOr<size_t> GetOnDeviceSizeInBytes() const override;
309 
CopyRawToHost(void * dst,int64_t offset,int64_t transfer_size)310   PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset,
311                                    int64_t transfer_size) override {
312     return PjRtFuture<Status>(Unimplemented("CopyRawToHost not implemented"));
313   }
314 
315   void Delete() override;
316 
317   bool IsDeleted() override;
318 
319   StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
320       PjRtDevice* dst_device) override;
321 
CopyToRemoteDevice(absl::string_view serialized_descriptor,RemoteSendCallback on_done)322   void CopyToRemoteDevice(absl::string_view serialized_descriptor,
323                           RemoteSendCallback on_done) override {
324     on_done(Unimplemented("CopyToRemoteDevice not implemented."),
325             /*sends_were_enqueued=*/false);
326   }
327 
CopyToRemoteDeviceScattered(absl::Span<const std::pair<std::string,RemoteSendCallback>> serialized_descriptors_and_callbacks,const ScatterDetails & scatter_details)328   void CopyToRemoteDeviceScattered(
329       absl::Span<const std::pair<std::string, RemoteSendCallback>>
330           serialized_descriptors_and_callbacks,
331       const ScatterDetails& scatter_details) override {
332     for (const auto& d_and_cb : serialized_descriptors_and_callbacks) {
333       d_and_cb.second(
334           Unimplemented("CopyToRemoteDeviceScattered not implemented."),
335           /*sends_were_enqueued=*/false);
336     }
337   }
338 
339   PjRtFuture<Status> GetReadyFuture() override;
340 
IsOnCpu()341   bool IsOnCpu() const override { return true; }
342 
343  private:
IsEmptyTuple()344   bool IsEmptyTuple() const {
345     return on_device_shape_.IsTuple() &&
346            on_device_shape_.tuple_shapes_size() == 0;
347   }
348 
349   StatusOr<tfrt::AsyncValueRef<Literal>> CopyToHostAsyncInternal(
350       bool discard_cached_copy, std::optional<xla::Layout> layout);
351 
352   // Acquires the device buffer for shared read-only usages, and it also adds
353   // the `usage_event` to it. Any donation event in the future is expected to be
354   // serialized after all the usage events added through this method. Returns
355   // nullptr if the buffer is already donated or there is outstanding external
356   // references.
357   TrackedTfrtCpuDeviceBuffer* AcquireUsage(
358       tfrt::AsyncValueRef<CpuEvent> usage_event);
359 
360   // A helper class for managing a pending donation. It should be committed upon
361   // success. Otherwise, the donated buffer is returned to the TfrtCpuBuffer.
362   class DonationTransaction {
363    public:
DonationTransaction(TfrtCpuBuffer * buffer,std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer)364     explicit DonationTransaction(
365         TfrtCpuBuffer* buffer,
366         std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer)
367         : buffer_(buffer), device_buffer_(std::move(device_buffer)) {
368       CHECK(buffer_);
369     }
370     DonationTransaction(const DonationTransaction&) = delete;
371     DonationTransaction& operator=(const DonationTransaction&) = delete;
372     DonationTransaction(DonationTransaction&&) = default;
373     DonationTransaction& operator=(DonationTransaction&& other) {
374       Abort();
375 
376       buffer_ = other.buffer_;
377       device_buffer_ = std::move(other.device_buffer_);
378       return *this;
379     }
380 
~DonationTransaction()381     ~DonationTransaction() { Abort(); }
382 
383     // Commit the donation. The rvalue ref qualifier is used to ensure the
384     // semantic that it can be committed at most once.
Commit()385     void Commit() && {
386       buffer_->CommitDonation();
387       device_buffer_.reset();
388     }
389 
device_buffer()390     TrackedTfrtCpuDeviceBuffer* device_buffer() const {
391       return device_buffer_.get();
392     }
393 
394    private:
Abort()395     void Abort() {
396       if (device_buffer_) buffer_->AbortDonation(std::move(device_buffer_));
397     }
398 
399     TfrtCpuBuffer* buffer_ = nullptr;
400     std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer_;
401   };
402 
403   // Acquires the device buffer for exclusive donation. The caller of this
404   // method is expected to use the usage events and definition events to
405   // serialize this donation with previous usages. After this method is called,
406   // calls to AcquireUsage() will fail. Returns error status if the buffer is
407   // already donated or there is outstanding external references.
408   StatusOr<DonationTransaction> AcquireDonation();
409 
DropExternalReference()410   void DropExternalReference() {
411     absl::MutexLock lock(&mu_);
412     CHECK_GT(external_reference_counter_, 0);
413     --external_reference_counter_;
414   }
415 
416   // Commits the pending donation by setting `pending_donation_` to false.
417   // `pending_donation_` must be true before calling this method.
418   void CommitDonation();
419 
420   // Aborts the pending donation by returning the donated buffer, and setting
421   // `pending_donation_` to false. `pending_donation_` must be true before
422   // calling this method.
423   void AbortDonation(std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer);
424 
425   // Similar to Delete, drops the buffer's reference to its associated device
426   // memory, leaving the buffer in an invalid state, but returns the
427   // TrackedTfrtCpuDeviceBuffer rather than freeing the device memory, so that
428   // another framework can take ownership of it. The buffer returned from
429   // Release may be safely dropped at any time even if it still has pending
430   // async operations. The client should call Await before calling Release with
431   // wait_for_operations_to_complete=false, to ensure that the host has
432   // synchronized past any outstanding write operations to the buffer. If
433   // wait_for_operations_to_complete=true the host will block until any
434   // potentially outstanding asynchronous operations have completed before
435   // returning, in which case it is safe to read or mutate the returned buffer.
436   // If the buffer was shared via an external reference it is the client's
437   // responsibility that accesses via that reference do not interfere with
438   // accesses via the buffer returned from Release.
439   StatusOr<std::unique_ptr<TrackedTfrtCpuDeviceBuffer>> Release(
440       bool wait_for_operations_to_complete);
441 
442   // Releases the device buffer by returning a unique_ptr of it. If there is
443   // outstanding donation or usage holds, this method blocks until those holds
444   // are commited or dropped.
445   std::unique_ptr<TrackedTfrtCpuDeviceBuffer> ReleaseBufferLocked()
446       ABSL_LOCKS_EXCLUDED(mu_);
447 
448   TfrtCpuClient* client_;
449   const Shape on_device_shape_;
450   TfrtCpuDevice* const device_;
451 
452   mutable absl::Mutex mu_;
453   std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer_
454       ABSL_GUARDED_BY(mu_);
455   // Count of external references on the buffer.
456   int external_reference_counter_ ABSL_GUARDED_BY(mu_) = 0;
457 
458   // `pending_donation_` indicates whether a donation is pending. The destructor
459   // of the TfrtCpuBuffer will wait for a pending donation, as the donation
460   // might fail. Note that concurrent calls to AcquireUsage() and
461   // AcquireDonation() might fail even if the pending donation is aborted later.
462   bool pending_donation_ ABSL_GUARDED_BY(mu_) = false;
463 
464   friend class TfrtCpuClient;
465   friend class TfrtCpuExecutable;
466 };
467 
468 class TfrtCpuExecutable final : public PjRtLoadedExecutable {
469  public:
470   TfrtCpuExecutable(
471       int num_replicas, int num_partitions,
472       std::shared_ptr<DeviceAssignment> device_assignment,
473       bool parameter_is_tupled_arguments,
474       std::unique_ptr<Executable> cpu_executable,
475       BufferAllocation::Index result_buffer_index,
476       absl::InlinedVector<BufferAllocation::Index, 4> result_buffer_indices,
477       std::vector<LogicalDeviceIds> addressable_device_logical_ids,
478       std::vector<PjRtDevice*> addressable_devices, TfrtCpuClient* client);
479 
480   ~TfrtCpuExecutable() override = default;
481 
client()482   TfrtCpuClient* client() const override { return client_; }
483 
name()484   absl::string_view name() const override {
485     return cpu_executable_->shared_module()->name();
486   }
487 
num_replicas()488   int num_replicas() const override { return num_replicas_; }
489 
num_partitions()490   int num_partitions() const override { return num_partitions_; }
491 
SizeOfGeneratedCodeInBytes()492   int64_t SizeOfGeneratedCodeInBytes() const override {
493     return cpu_executable_->SizeOfGeneratedCodeInBytes();
494   }
495 
device_assignment()496   const DeviceAssignment& device_assignment() const override {
497     return *device_assignment_;
498   }
499 
addressable_device_logical_ids()500   absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
501       const override {
502     return addressable_device_logical_ids_;
503   }
504 
addressable_devices()505   absl::Span<PjRtDevice* const> addressable_devices() const override {
506     return addressable_devices_;
507   }
508 
GetHloModules()509   StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
510       const override {
511     return std::vector<std::shared_ptr<HloModule>>{
512         cpu_executable_->shared_module()};
513   }
514 
515   using PjRtLoadedExecutable::Execute;
516   StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
517       absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
518       const ExecuteOptions& options,
519       std::optional<std::vector<PjRtFuture<Status>>>& returned_futures)
520       override;
521 
522   using PjRtLoadedExecutable::ExecuteSharded;
523   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
524       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
525       const ExecuteOptions& options,
526       std::optional<PjRtFuture<Status>>& returned_future,
527       bool fill_future) override;
528 
529   using PjRtLoadedExecutable::ExecutePortable;
530   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
531       absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
532       const ExecuteOptions& options,
533       std::optional<PjRtFuture<Status>>& returned_future,
534       bool fill_future) override;
535 
536   void Delete() override;
537 
538   bool IsDeleted() override;
539 
IsReturnedFutureSupported()540   bool IsReturnedFutureSupported() const override { return true; }
541 
542   StatusOr<std::optional<std::string>> Fingerprint() const;
543 
544  private:
545   friend class TfrtCpuClient;
546 
547   Status SetUpDonation(bool tuple_inputs);
548 
549   // Checks that the input buffers passed in by the user have the correct size
550   // on device for the compiled program.
551   Status CheckBufferCompatibilities(
552       absl::Span<TrackedTfrtCpuDeviceBuffer* const> input_buffers) const;
553 
554   StatusOr<Result> ExecuteHelper(
555       absl::Span<PjRtBuffer* const> argument_handles, int replica,
556       int partition, const RunId& run_id, const ExecuteOptions& options,
557       tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event,
558       bool fill_future, TfrtCpuDevice* device = nullptr);
559 
560   TfrtCpuClient* client_;
561 
562   int num_replicas_;
563   int num_partitions_;
564   std::shared_ptr<DeviceAssignment> device_assignment_;
565   bool parameter_is_tupled_arguments_;
566 
567   std::shared_ptr<Executable> cpu_executable_;
568 
569   // Caching `result_buffer_index_` and `result_buffer_indices_` to avoid lookup
570   // HLO dataflow analysis data structures in program execution critical path.
571 
572   // Buffer allocation index corresponding to root buffer buffer.
573   BufferAllocation::Index result_buffer_index_;
574   // Buffer allocation indices corresponding to each result buffer leaf buffer.
575   absl::InlinedVector<BufferAllocation::Index, 4> result_buffer_indices_;
576 
577   // Size on device of each leaf buffer of the compiled program, cached here
578   // for performance reasons.
579   std::vector<int64_t> input_buffer_sizes_in_bytes_;
580 
581   // A sorted vector of parameters that have any aliased buffers and thus must
582   // be donated when executing the computation.
583   std::vector<int> parameters_that_must_be_donated_;
584 
585   // The replica and partition indices of device_assignment_ to be run by this
586   // client. On single-host platforms without partitioning, this is all
587   // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may
588   // not be the case on multi-host platforms. If there are 4 replicas and 2
589   // partitions on a single host platform, size of
590   // addressable_device_logical_ids_ is 4*2 = 8.
591   std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
592 
593   // addressable_devices_[i] is the Device to which
594   // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
595   // unique_ptrs to play well with the Python bindings (see xla.cc).
596   std::vector<PjRtDevice*> addressable_devices_;
597 
598   // Cached result of comparing HloCostAnalysis FLOP estimate for execute
599   // critical path.
600   bool cheap_computation_;
601 };
602 
603 // Creates a CPU client with one Device. For testing purposes, you can set the
604 // number of devices passing the --xla_force_host_platform_device_count flag to
605 // the XLA_FLAGS environment variable.
606 StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous);
607 
608 // Similar to the function above, but you can set the number of devices
609 // explicitly.
610 StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous,
611                                                        int cpu_device_count);
612 
613 }  // namespace xla
614 
615 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_
616