1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 24 #include "absl/base/thread_annotations.h" 25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 #include "tensorflow/compiler/xla/client/executable_build_options.h" 27 #include "tensorflow/compiler/xla/client/xla_computation.h" 28 #include "tensorflow/compiler/xla/layout.h" 29 #include "tensorflow/compiler/xla/literal.h" 30 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 31 #include "tensorflow/compiler/xla/pjrt/pjrt_future.h" 32 #include "tensorflow/compiler/xla/pjrt/semaphore.h" 33 #include "tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h" 34 #include "tensorflow/compiler/xla/pjrt/transpose.h" 35 #include "tensorflow/compiler/xla/pjrt/worker_thread.h" 36 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 37 #include "tensorflow/compiler/xla/service/computation_placer.h" 38 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" 39 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" 40 #include "tensorflow/compiler/xla/service/executable.h" 41 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" 42 #include "tensorflow/compiler/xla/service/hlo_module_util.h" 43 #include "tensorflow/compiler/xla/statusor.h" 44 #include "tensorflow/compiler/xla/xla_data.pb.h" 45 #include "tensorflow/core/profiler/lib/traceme.h" 46 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime 47 #include "tfrt/host_context/host_context.h" // from @tf_runtime 48 49 namespace xla { 50 51 class TfrtCpuDevice final : public PjRtDevice { 52 public: 53 TfrtCpuDevice(int id, bool asynchronous); 54 SetClient(PjRtClient * client)55 void SetClient(PjRtClient* client) { 56 CHECK(client_ == nullptr); 57 client_ = client; 58 } 59 client()60 PjRtClient* client() const override { return client_; } 61 IsAddressable()62 bool IsAddressable() const override { 63 return process_index() == client()->process_index(); 64 } 65 id()66 int id() const override { return id_; } 67 process_index()68 int process_index() const override { return 0; } 69 70 // Used as `device_ordinal`. local_hardware_id()71 int local_hardware_id() const override { return id_; } 72 73 absl::string_view device_kind() const override; 74 75 absl::string_view DebugString() const override; 76 77 absl::string_view ToString() const override; 78 79 Status TransferToInfeed(const LiteralSlice& literal) override; 80 81 Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; 82 83 // Returns a semaphore for admission control on inflight computations. max_inflight_computations_semaphore()84 Semaphore& max_inflight_computations_semaphore() { 85 return max_inflight_computations_semaphore_; 86 } 87 CreateAsyncTrackingEvent(absl::string_view description)88 std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent( 89 absl::string_view description) const override { 90 return nullptr; 91 } 92 Attributes()93 const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes() 94 const override { 95 return attributes_; 96 } 97 98 private: 99 int id_; 100 PjRtClient* client_ = nullptr; 101 std::string debug_string_; 102 std::string to_string_; 103 104 // TODO(zhangqiaorjc): Optimize semaphore related overhead. 105 // Semaphore used to limit how many programs can be enqueued by the host 106 // ahead of the device. 107 Semaphore max_inflight_computations_semaphore_; 108 absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_ = {}; 109 }; 110 111 class TfrtCpuExecutable; 112 113 class TfrtCpuClient final : public PjRtClient { 114 public: 115 TfrtCpuClient(int process_index, 116 std::vector<std::unique_ptr<TfrtCpuDevice>> devices, 117 std::unique_ptr<tfrt::HostContext> host_ctx); 118 ~TfrtCpuClient(); 119 process_index()120 int process_index() const override { return process_index_; } 121 device_count()122 int device_count() const override { return devices_.size(); } 123 addressable_device_count()124 int addressable_device_count() const override { 125 return addressable_devices_.size(); 126 } 127 devices()128 absl::Span<PjRtDevice* const> devices() const override { return devices_; } 129 addressable_devices()130 absl::Span<PjRtDevice* const> addressable_devices() const override { 131 return addressable_devices_; 132 } 133 134 StatusOr<PjRtDevice*> LookupDevice(int device_id) const override; 135 136 StatusOr<PjRtDevice*> LookupAddressableDevice( 137 int local_hardware_id) const override; 138 platform_id()139 PjRtPlatformId platform_id() const override { 140 return tensorflow::Fingerprint64(CpuName()); 141 } 142 platform_name()143 absl::string_view platform_name() const override { return CpuName(); } 144 platform_version()145 absl::string_view platform_version() const override { return "<unknown>"; } 146 runtime_type()147 PjRtRuntimeType runtime_type() const override { return kTfrt; } 148 149 StatusOr<DeviceAssignment> GetDefaultDeviceAssignment( 150 int num_replicas, int num_partitions) const override; 151 152 StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override; 153 154 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile( 155 const XlaComputation& computation, CompileOptions options) override; 156 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile( 157 mlir::ModuleOp module, CompileOptions options) override; 158 159 StatusOr<std::optional<std::string>> ExecutableFingerprint( 160 const PjRtLoadedExecutable& executable) const override; 161 SerializeExecutable(const PjRtLoadedExecutable & executable)162 StatusOr<std::string> SerializeExecutable( 163 const PjRtLoadedExecutable& executable) const override { 164 return Unimplemented("SerializeExecutable not implemented on %s", 165 platform_name()); 166 } 167 DeserializeExecutable(absl::string_view serialized,CompileOptions options)168 StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable( 169 absl::string_view serialized, CompileOptions options) override { 170 return Unimplemented("DeserializeExecutable not implemented on %s", 171 platform_name()); 172 } 173 174 StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer( 175 const Shape& shape, PjRtDevice* device) override; 176 177 StatusOr<std::unique_ptr<PjRtClient::AsyncBufferTransferManager>> CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,PjRtDevice * device)178 CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes, 179 PjRtDevice* device) override { 180 return Unimplemented("Async transfer to buffers not implemented"); 181 }; 182 183 StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer( 184 const void* data, PrimitiveType type, absl::Span<int64_t const> dims, 185 std::optional<absl::Span<int64_t const>> byte_strides, 186 HostBufferSemantics host_buffer_semantics, 187 std::function<void()> on_done_with_host_buffer, 188 PjRtDevice* device) override; 189 190 StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral( 191 const LiteralSlice& literal, PjRtDevice* device) override; 192 193 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)194 MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes, 195 PjRtDevice* device, 196 PjRtCrossHostRecvNotifier notifier) override { 197 return Unimplemented("MakeCrossHostReceiveBuffers not implemented."); 198 } 199 200 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> MakeCrossHostReceiveBuffersForGather(absl::Span<const Shape> shapes,std::vector<GatherDetails> gather_details,PjRtDevice * device,PjRtCrossHostRecvNotifier notifier)201 MakeCrossHostReceiveBuffersForGather( 202 absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details, 203 PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override { 204 return Unimplemented( 205 "MakeCrossHostReceiveBuffersForGather not implemented."); 206 } 207 208 StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer( 209 void* device_ptr, const Shape& shape, PjRtDevice* device, 210 std::function<void()> on_delete_callback) override; 211 CreateChannelHandle()212 StatusOr<ChannelHandle> CreateChannelHandle() override { 213 return Unimplemented("CreateChannelHandle not implemented."); 214 } CreateDeviceToHostChannelHandle()215 StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override { 216 return Unimplemented("CreateDeviceToHostChannelHandle not implemented."); 217 } CreateHostToDeviceChannelHandle()218 StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override { 219 return Unimplemented("CreateHostToDeviceChannelHandle not implemented."); 220 } 221 Defragment()222 Status Defragment() override { 223 return Unimplemented("Defragment not implemented."); 224 } 225 GetHostContext()226 tfrt::HostContext* GetHostContext() const { return host_ctx_.get(); } 227 eigen_intraop_device()228 Eigen::ThreadPoolDevice* eigen_intraop_device() const { 229 return eigen_intraop_device_.get(); 230 } 231 GetLastCollectiveLaunchEvent()232 tfrt::AsyncValueRef<CpuEvent> GetLastCollectiveLaunchEvent() { 233 absl::MutexLock lock(&mu_); 234 return last_collective_launch_event_.CopyRef(); 235 } 236 SetLastCollectiveLaunchEvent(tfrt::AsyncValueRef<CpuEvent> event)237 void SetLastCollectiveLaunchEvent(tfrt::AsyncValueRef<CpuEvent> event) { 238 absl::MutexLock lock(&mu_); 239 last_collective_launch_event_ = std::move(event); 240 } 241 242 private: 243 int process_index_; 244 // Includes all devices, including non-addressable devices. 245 std::vector<std::unique_ptr<TfrtCpuDevice>> owned_devices_; 246 // Pointers to `owned_devices_`. 247 std::vector<PjRtDevice*> devices_; 248 // Maps Device::id() to the corresponding Device. Includes all devices. 249 absl::flat_hash_map<int, TfrtCpuDevice*> id_to_device_; 250 // Addressable devices indexed by core_id. 251 std::vector<PjRtDevice*> addressable_devices_; 252 std::unique_ptr<tfrt::HostContext> host_ctx_; 253 std::unique_ptr<ComputationPlacer> computation_placer_; 254 255 // TODO(zhangqiaorjc): Use tfrt::compat::EigenHostContextThreadPool. 256 std::unique_ptr<tensorflow::thread::ThreadPool> eigen_intraop_pool_; 257 std::unique_ptr<Eigen::ThreadPoolDevice> eigen_intraop_device_; 258 259 // Launching collectives are prone to deadlock when we use fixed-sized 260 // threadpools since ExecuteHelper will block until all replicas reach the 261 // barrier. We ensure that 262 // 1. Threadpool size is at least as large as device_count so one collective 263 // launch over all devices can succeed. 264 // 2. Gang-schedule each collective by conservatively ensuring a total order 265 // of collectives and launching only one collective at a time to avoid 266 // having no active threads to make progress 267 // TODO(zhangqiaorjc): Explore alternatives that allow multiple concurrent 268 // collectives. 269 mutable absl::Mutex mu_; 270 tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event_ 271 ABSL_GUARDED_BY(mu_); 272 273 // A cache for transpose plans. We use transposes to convert 274 // (possibly strided) buffers provided to BufferFromHostBuffer into dense 275 // major-to-minor layout. 276 absl::Mutex transpose_mu_; 277 TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_); 278 }; 279 280 class TfrtCpuBuffer final : public PjRtBuffer { 281 public: 282 TfrtCpuBuffer( 283 Shape on_device_shape, 284 std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer, 285 TfrtCpuClient* client, TfrtCpuDevice* device); 286 ~TfrtCpuBuffer() override; 287 288 TfrtCpuBuffer(const TfrtCpuBuffer&) = delete; 289 TfrtCpuBuffer(TfrtCpuBuffer&&) = delete; 290 TfrtCpuBuffer& operator=(const TfrtCpuBuffer&) = delete; 291 TfrtCpuBuffer& operator=(TfrtCpuBuffer&&) = delete; 292 on_device_shape()293 const Shape& on_device_shape() const override { return on_device_shape_; } device()294 TfrtCpuDevice* device() const override { return device_; } client()295 TfrtCpuClient* client() const override { return client_; } 296 297 StatusOr<Shape> logical_on_device_shape() override; 298 299 StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference() 300 override; 301 302 StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership( 303 bool wait_for_operations_to_complete) override; 304 305 using PjRtBuffer::ToLiteralSync; 306 PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override; 307 308 StatusOr<size_t> GetOnDeviceSizeInBytes() const override; 309 CopyRawToHost(void * dst,int64_t offset,int64_t transfer_size)310 PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset, 311 int64_t transfer_size) override { 312 return PjRtFuture<Status>(Unimplemented("CopyRawToHost not implemented")); 313 } 314 315 void Delete() override; 316 317 bool IsDeleted() override; 318 319 StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice( 320 PjRtDevice* dst_device) override; 321 CopyToRemoteDevice(absl::string_view serialized_descriptor,RemoteSendCallback on_done)322 void CopyToRemoteDevice(absl::string_view serialized_descriptor, 323 RemoteSendCallback on_done) override { 324 on_done(Unimplemented("CopyToRemoteDevice not implemented."), 325 /*sends_were_enqueued=*/false); 326 } 327 CopyToRemoteDeviceScattered(absl::Span<const std::pair<std::string,RemoteSendCallback>> serialized_descriptors_and_callbacks,const ScatterDetails & scatter_details)328 void CopyToRemoteDeviceScattered( 329 absl::Span<const std::pair<std::string, RemoteSendCallback>> 330 serialized_descriptors_and_callbacks, 331 const ScatterDetails& scatter_details) override { 332 for (const auto& d_and_cb : serialized_descriptors_and_callbacks) { 333 d_and_cb.second( 334 Unimplemented("CopyToRemoteDeviceScattered not implemented."), 335 /*sends_were_enqueued=*/false); 336 } 337 } 338 339 PjRtFuture<Status> GetReadyFuture() override; 340 IsOnCpu()341 bool IsOnCpu() const override { return true; } 342 343 private: IsEmptyTuple()344 bool IsEmptyTuple() const { 345 return on_device_shape_.IsTuple() && 346 on_device_shape_.tuple_shapes_size() == 0; 347 } 348 349 StatusOr<tfrt::AsyncValueRef<Literal>> CopyToHostAsyncInternal( 350 bool discard_cached_copy, std::optional<xla::Layout> layout); 351 352 // Acquires the device buffer for shared read-only usages, and it also adds 353 // the `usage_event` to it. Any donation event in the future is expected to be 354 // serialized after all the usage events added through this method. Returns 355 // nullptr if the buffer is already donated or there is outstanding external 356 // references. 357 TrackedTfrtCpuDeviceBuffer* AcquireUsage( 358 tfrt::AsyncValueRef<CpuEvent> usage_event); 359 360 // A helper class for managing a pending donation. It should be committed upon 361 // success. Otherwise, the donated buffer is returned to the TfrtCpuBuffer. 362 class DonationTransaction { 363 public: DonationTransaction(TfrtCpuBuffer * buffer,std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer)364 explicit DonationTransaction( 365 TfrtCpuBuffer* buffer, 366 std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer) 367 : buffer_(buffer), device_buffer_(std::move(device_buffer)) { 368 CHECK(buffer_); 369 } 370 DonationTransaction(const DonationTransaction&) = delete; 371 DonationTransaction& operator=(const DonationTransaction&) = delete; 372 DonationTransaction(DonationTransaction&&) = default; 373 DonationTransaction& operator=(DonationTransaction&& other) { 374 Abort(); 375 376 buffer_ = other.buffer_; 377 device_buffer_ = std::move(other.device_buffer_); 378 return *this; 379 } 380 ~DonationTransaction()381 ~DonationTransaction() { Abort(); } 382 383 // Commit the donation. The rvalue ref qualifier is used to ensure the 384 // semantic that it can be committed at most once. Commit()385 void Commit() && { 386 buffer_->CommitDonation(); 387 device_buffer_.reset(); 388 } 389 device_buffer()390 TrackedTfrtCpuDeviceBuffer* device_buffer() const { 391 return device_buffer_.get(); 392 } 393 394 private: Abort()395 void Abort() { 396 if (device_buffer_) buffer_->AbortDonation(std::move(device_buffer_)); 397 } 398 399 TfrtCpuBuffer* buffer_ = nullptr; 400 std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer_; 401 }; 402 403 // Acquires the device buffer for exclusive donation. The caller of this 404 // method is expected to use the usage events and definition events to 405 // serialize this donation with previous usages. After this method is called, 406 // calls to AcquireUsage() will fail. Returns error status if the buffer is 407 // already donated or there is outstanding external references. 408 StatusOr<DonationTransaction> AcquireDonation(); 409 DropExternalReference()410 void DropExternalReference() { 411 absl::MutexLock lock(&mu_); 412 CHECK_GT(external_reference_counter_, 0); 413 --external_reference_counter_; 414 } 415 416 // Commits the pending donation by setting `pending_donation_` to false. 417 // `pending_donation_` must be true before calling this method. 418 void CommitDonation(); 419 420 // Aborts the pending donation by returning the donated buffer, and setting 421 // `pending_donation_` to false. `pending_donation_` must be true before 422 // calling this method. 423 void AbortDonation(std::unique_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer); 424 425 // Similar to Delete, drops the buffer's reference to its associated device 426 // memory, leaving the buffer in an invalid state, but returns the 427 // TrackedTfrtCpuDeviceBuffer rather than freeing the device memory, so that 428 // another framework can take ownership of it. The buffer returned from 429 // Release may be safely dropped at any time even if it still has pending 430 // async operations. The client should call Await before calling Release with 431 // wait_for_operations_to_complete=false, to ensure that the host has 432 // synchronized past any outstanding write operations to the buffer. If 433 // wait_for_operations_to_complete=true the host will block until any 434 // potentially outstanding asynchronous operations have completed before 435 // returning, in which case it is safe to read or mutate the returned buffer. 436 // If the buffer was shared via an external reference it is the client's 437 // responsibility that accesses via that reference do not interfere with 438 // accesses via the buffer returned from Release. 439 StatusOr<std::unique_ptr<TrackedTfrtCpuDeviceBuffer>> Release( 440 bool wait_for_operations_to_complete); 441 442 // Releases the device buffer by returning a unique_ptr of it. If there is 443 // outstanding donation or usage holds, this method blocks until those holds 444 // are commited or dropped. 445 std::unique_ptr<TrackedTfrtCpuDeviceBuffer> ReleaseBufferLocked() 446 ABSL_LOCKS_EXCLUDED(mu_); 447 448 TfrtCpuClient* client_; 449 const Shape on_device_shape_; 450 TfrtCpuDevice* const device_; 451 452 mutable absl::Mutex mu_; 453 std::unique_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer_ 454 ABSL_GUARDED_BY(mu_); 455 // Count of external references on the buffer. 456 int external_reference_counter_ ABSL_GUARDED_BY(mu_) = 0; 457 458 // `pending_donation_` indicates whether a donation is pending. The destructor 459 // of the TfrtCpuBuffer will wait for a pending donation, as the donation 460 // might fail. Note that concurrent calls to AcquireUsage() and 461 // AcquireDonation() might fail even if the pending donation is aborted later. 462 bool pending_donation_ ABSL_GUARDED_BY(mu_) = false; 463 464 friend class TfrtCpuClient; 465 friend class TfrtCpuExecutable; 466 }; 467 468 class TfrtCpuExecutable final : public PjRtLoadedExecutable { 469 public: 470 TfrtCpuExecutable( 471 int num_replicas, int num_partitions, 472 std::shared_ptr<DeviceAssignment> device_assignment, 473 bool parameter_is_tupled_arguments, 474 std::unique_ptr<Executable> cpu_executable, 475 BufferAllocation::Index result_buffer_index, 476 absl::InlinedVector<BufferAllocation::Index, 4> result_buffer_indices, 477 std::vector<LogicalDeviceIds> addressable_device_logical_ids, 478 std::vector<PjRtDevice*> addressable_devices, TfrtCpuClient* client); 479 480 ~TfrtCpuExecutable() override = default; 481 client()482 TfrtCpuClient* client() const override { return client_; } 483 name()484 absl::string_view name() const override { 485 return cpu_executable_->shared_module()->name(); 486 } 487 num_replicas()488 int num_replicas() const override { return num_replicas_; } 489 num_partitions()490 int num_partitions() const override { return num_partitions_; } 491 SizeOfGeneratedCodeInBytes()492 int64_t SizeOfGeneratedCodeInBytes() const override { 493 return cpu_executable_->SizeOfGeneratedCodeInBytes(); 494 } 495 device_assignment()496 const DeviceAssignment& device_assignment() const override { 497 return *device_assignment_; 498 } 499 addressable_device_logical_ids()500 absl::Span<const LogicalDeviceIds> addressable_device_logical_ids() 501 const override { 502 return addressable_device_logical_ids_; 503 } 504 addressable_devices()505 absl::Span<PjRtDevice* const> addressable_devices() const override { 506 return addressable_devices_; 507 } 508 GetHloModules()509 StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules() 510 const override { 511 return std::vector<std::shared_ptr<HloModule>>{ 512 cpu_executable_->shared_module()}; 513 } 514 515 using PjRtLoadedExecutable::Execute; 516 StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute( 517 absl::Span<const std::vector<PjRtBuffer*>> argument_handles, 518 const ExecuteOptions& options, 519 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) 520 override; 521 522 using PjRtLoadedExecutable::ExecuteSharded; 523 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded( 524 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, 525 const ExecuteOptions& options, 526 std::optional<PjRtFuture<Status>>& returned_future, 527 bool fill_future) override; 528 529 using PjRtLoadedExecutable::ExecutePortable; 530 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable( 531 absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, 532 const ExecuteOptions& options, 533 std::optional<PjRtFuture<Status>>& returned_future, 534 bool fill_future) override; 535 536 void Delete() override; 537 538 bool IsDeleted() override; 539 IsReturnedFutureSupported()540 bool IsReturnedFutureSupported() const override { return true; } 541 542 StatusOr<std::optional<std::string>> Fingerprint() const; 543 544 private: 545 friend class TfrtCpuClient; 546 547 Status SetUpDonation(bool tuple_inputs); 548 549 // Checks that the input buffers passed in by the user have the correct size 550 // on device for the compiled program. 551 Status CheckBufferCompatibilities( 552 absl::Span<TrackedTfrtCpuDeviceBuffer* const> input_buffers) const; 553 554 StatusOr<Result> ExecuteHelper( 555 absl::Span<PjRtBuffer* const> argument_handles, int replica, 556 int partition, const RunId& run_id, const ExecuteOptions& options, 557 tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event, 558 bool fill_future, TfrtCpuDevice* device = nullptr); 559 560 TfrtCpuClient* client_; 561 562 int num_replicas_; 563 int num_partitions_; 564 std::shared_ptr<DeviceAssignment> device_assignment_; 565 bool parameter_is_tupled_arguments_; 566 567 std::shared_ptr<Executable> cpu_executable_; 568 569 // Caching `result_buffer_index_` and `result_buffer_indices_` to avoid lookup 570 // HLO dataflow analysis data structures in program execution critical path. 571 572 // Buffer allocation index corresponding to root buffer buffer. 573 BufferAllocation::Index result_buffer_index_; 574 // Buffer allocation indices corresponding to each result buffer leaf buffer. 575 absl::InlinedVector<BufferAllocation::Index, 4> result_buffer_indices_; 576 577 // Size on device of each leaf buffer of the compiled program, cached here 578 // for performance reasons. 579 std::vector<int64_t> input_buffer_sizes_in_bytes_; 580 581 // A sorted vector of parameters that have any aliased buffers and thus must 582 // be donated when executing the computation. 583 std::vector<int> parameters_that_must_be_donated_; 584 585 // The replica and partition indices of device_assignment_ to be run by this 586 // client. On single-host platforms without partitioning, this is all 587 // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may 588 // not be the case on multi-host platforms. If there are 4 replicas and 2 589 // partitions on a single host platform, size of 590 // addressable_device_logical_ids_ is 4*2 = 8. 591 std::vector<LogicalDeviceIds> addressable_device_logical_ids_; 592 593 // addressable_devices_[i] is the Device to which 594 // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of 595 // unique_ptrs to play well with the Python bindings (see xla.cc). 596 std::vector<PjRtDevice*> addressable_devices_; 597 598 // Cached result of comparing HloCostAnalysis FLOP estimate for execute 599 // critical path. 600 bool cheap_computation_; 601 }; 602 603 // Creates a CPU client with one Device. For testing purposes, you can set the 604 // number of devices passing the --xla_force_host_platform_device_count flag to 605 // the XLA_FLAGS environment variable. 606 StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous); 607 608 // Similar to the function above, but you can set the number of devices 609 // explicitly. 610 StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous, 611 int cpu_device_count); 612 613 } // namespace xla 614 615 #endif // TENSORFLOW_COMPILER_XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_ 616