/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_

#include <array>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_future.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/transpose.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

class PjRtStreamExecutorDevice : public PjRtDevice {
 public:
  explicit PjRtStreamExecutorDevice(
      int id, std::unique_ptr<LocalDeviceState> local_device_state,
      std::string device_kind, int process_index = 0)
      : id_(id),
        device_ordinal_(
            local_device_state ? local_device_state->device_ordinal() : -1),
        local_device_state_(std::move(local_device_state)),
        process_index_(process_index),
        device_kind_(std::move(device_kind)) {}

  ~PjRtStreamExecutorDevice() override {}

  // Must set client exactly once.
  void SetClient(PjRtClient* client) {
    CHECK(client_ == nullptr);
    client_ = client;
    // We have to define debug_string_ and to_string_ here, because
    // platform_name() requires client_ to be set.
    debug_string_ = absl::StrCat(platform_name(), ":", id());
    to_string_ = absl::StrCat(platform_name(), "(id=", id(), ")");
  }

  int process_index() const override { return process_index_; }

  // Return `platform_id` from client.
  PjRtPlatformId platform_id() const;

  // Return `platform_name` from client.
  absl::string_view platform_name() const;

  PjRtClient* client() const override { return client_; }

  int id() const override { return id_; }

  bool IsAddressable() const override { return device_ordinal_ != -1; }

  int local_hardware_id() const override { return device_ordinal_; }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns nullptr if the device
  // is not local to this host.
  LocalDeviceState* local_device_state() const {
    return local_device_state_.get();
  }

  // If this is a device local to this host, returns a LocalDeviceState object
  // that can be used to manipulate the device. Returns an error if the device
  // is not local to this host.
  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;

  absl::string_view device_kind() const override { return device_kind_; }

  absl::string_view ToString() const override;

  absl::string_view DebugString() const override;

  Status TransferToInfeed(const LiteralSlice& literal) override;

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;

  std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
      absl::string_view description) const override {
    return nullptr;
  }

  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
      const override {
    return attributes_;
  }

 protected:
  absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;

 private:
  const int id_;
  const int device_ordinal_;  // -1 means not local.
  const std::unique_ptr<LocalDeviceState> local_device_state_;
  const int process_index_;
  const std::string device_kind_;
  std::string debug_string_;
  std::string to_string_;
  PjRtClient* client_ = nullptr;
};

class PjRtStreamExecutorClient : public PjRtClient {
 public:
  // `allocator` may be null, in which case the platform default allocator is
  // used.
  explicit PjRtStreamExecutorClient(
      std::string platform_name, LocalClient* client,
      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
      int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
      bool should_stage_host_to_device_transfers,
      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
  ~PjRtStreamExecutorClient() override = default;

  int process_index() const override { return process_index_; }

  int device_count() const override { return devices_.size(); }
  int addressable_device_count() const override {
    return addressable_devices_.size();
  }
  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
    auto it = id_to_device_.find(device_id);
    if (it != id_to_device_.end()) {
      return it->second;
    }
    return InvalidArgument("No matching device found for device_id %d",
                           device_id);
  }

  StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const override;

  PjRtPlatformId platform_id() const override { return platform_id_; }
  absl::string_view platform_name() const override { return platform_name_; }
  absl::string_view platform_version() const override { return "<unknown>"; }
  PjRtRuntimeType runtime_type() const override { return kStreamExecutor; }

  // Most platforms expect device-to-device transfers to be enqueued on the
  // source d2d stream, but some platforms use the destination d2d stream. This
  // function specifies which one the platform expects.
  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;

  StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) override;
  StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
      mlir::ModuleOp mlir_module, CompileOptions options) override;
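
  // Example: a hypothetical sketch of the typical compile flow. The builder
  // boilerplate (`XlaBuilder`, `Parameter`, `Add`) comes from the XLA client
  // API and is illustrative only; of the calls below, only `Compile` belongs
  // to this class.
  //
  //   XlaBuilder builder("add_one");
  //   auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
  //   Add(x, ConstantR0<float>(&builder, 1.0f));
  //   TF_ASSIGN_OR_RETURN(XlaComputation computation, builder.Build());
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
  //                       client->Compile(computation, CompileOptions()));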
  StatusOr<std::optional<std::string>> ExecutableFingerprint(
      const PjRtLoadedExecutable& executable) const override {
    return std::optional<std::string>();
  }

  StatusOr<std::string> SerializeExecutable(
      const PjRtLoadedExecutable& executable) const override {
    return Unimplemented("SerializeExecutable not implemented on %s",
                         platform_name());
  }

  StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
      absl::string_view serialized, CompileOptions options) override {
    return Unimplemented("DeserializeExecutable not implemented on %s",
                         platform_name());
  }

  StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;

  // Creates a buffer on the device without initializing or copying any data.
  // An optional `definition_event` may be specified that can be used to
  // ensure the buffer isn't referenced until some external mechanism has
  // initialized the data.
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) override;
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device,
      std::shared_ptr<BufferSequencingEvent> definition_event);
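
  // Example: a hypothetical sketch of the `definition_event` variant. How the
  // event is created and eventually signaled depends on the external producer;
  // those details are elided here.
  //
  //   std::shared_ptr<BufferSequencingEvent> definition_event = ...;
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<PjRtBuffer> buffer,
  //       client->CreateUninitializedBuffer(shape, device, definition_event));
  //   // An external mechanism now fills the device memory; consumers of
  //   // `buffer` will not treat the data as defined until `definition_event`
  //   // has been triggered.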
  StatusOr<std::unique_ptr<PjRtClient::AsyncBufferTransferManager>>
  CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
                                PjRtDevice* device) override {
    return Unimplemented("Async transfer to buffers not implemented");
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
      const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
      std::optional<absl::Span<int64_t const>> byte_strides,
      HostBufferSemantics host_buffer_semantics,
      std::function<void()> on_done_with_host_buffer,
      PjRtDevice* device) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
      const LiteralSlice& literal, PjRtDevice* device) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
  MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
                              PjRtDevice* device,
                              PjRtCrossHostRecvNotifier notifier) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
  MakeCrossHostReceiveBuffersForGather(
      absl::Span<const Shape> shapes,
      std::vector<GatherDetails> gather_details, PjRtDevice* device,
      PjRtCrossHostRecvNotifier notifier) override;

  StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
      void* device_ptr, const Shape& shape, PjRtDevice* device,
      std::function<void()> on_delete_callback) override;

  StatusOr<ChannelHandle> CreateChannelHandle() override {
    return client()->CreateChannelHandle();
  }
  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
    return client()->CreateDeviceToHostChannelHandle();
  }
  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
    return client()->CreateHostToDeviceChannelHandle();
  }

  // TODO(zhangqiaorjc): Experimental. Will be removed.
  Status Defragment() override {
    return Unimplemented("Defragment not implemented");
  }

  LocalDeviceState& device_state(int device_ordinal) const {
    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
                LookupAddressableDevice(device_ordinal).ValueOrDie())
                ->local_device_state();
  }
  LocalClient* client() const { return client_; }
  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
  tensorflow::Allocator* host_memory_allocator() const {
    return host_memory_allocator_.get();
  }
  bool should_stage_host_to_device_transfers() const {
    return should_stage_host_to_device_transfers_;
  }

  gpu::GpuExecutableRunOptions* gpu_run_options() const {
    return gpu_run_options_.get();
  }

  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }

 protected:
  friend class PjRtStreamExecutorBuffer;

  virtual Status EnqueueCrossHostReceive(
      absl::Span<const std::unique_ptr<PjRtBuffer>> buffers,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtCrossHostRecvNotifier notifier,
      std::optional<std::vector<GatherDetails>> gather_details) const {
    return Unimplemented("Cross host receives not implemented.");
  }

  virtual void CopyToRemoteDevice(
      PjRtBuffer* buffer, absl::string_view serialized_descriptor,
      PjRtBuffer::RemoteSendCallback on_done) const {
    on_done(Unimplemented("Cross host sends not implemented."),
            /*sends_were_enqueued=*/false);
  }

  virtual void CopyToRemoteDeviceScattered(
      PjRtBuffer* buffer,
      absl::Span<const std::pair<std::string, PjRtBuffer::RemoteSendCallback>>
          serialized_descriptors_and_callbacks,
      const PjRtBuffer::ScatterDetails& scatter_details) const {
    for (const auto& d_and_cb : serialized_descriptors_and_callbacks) {
      d_and_cb.second(
          Unimplemented("Scattered cross host sends not implemented."),
          /*sends_were_enqueued=*/false);
    }
  }

  virtual PjRtFuture<Status> CopyRawSubBufferToHost(PjRtBuffer* buffer,
                                                    void* dst, int64_t offset,
                                                    int64_t transfer_size) {
    return PjRtFuture<Status>(
        Unimplemented("Raw copies to host not implemented."));
  }

  // Helper function for creating PjRtStreamExecutorExecutables. Modifies
  // `options` in-place.
  struct ExecutableExtras {
    std::shared_ptr<DeviceAssignment> device_assignment;
    std::vector<PjRtLoadedExecutable::LogicalDeviceIds>
        addressable_device_logical_ids;
    std::vector<PjRtDevice*> addressable_devices;
  };
  StatusOr<ExecutableExtras> GetExecutableExtras(CompileOptions* options);

  const PjRtPlatformId platform_id_;
  const std::string platform_name_;
  LocalClient* client_;

  // Allocator to be used for staging memory transfers to devices.
  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;

  // Device memory allocator. If owned, the allocator must outlive the devices,
  // because it is the device destructor that waits for any outstanding work to
  // complete.
  se::DeviceMemoryAllocator* allocator_;
  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;

  // Includes all devices, including non-local devices on multi-host platforms.
  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
  // Pointers to `owned_devices_`.
  std::vector<PjRtDevice*> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, PjRtDevice*> id_to_device_;
  // Local devices indexed by local device ordinal.
  std::vector<PjRtDevice*> addressable_devices_;
  int process_index_;

  // Should we always prefer to stage host-to-device transfers via memory
  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
  // transfer via pinned memory.
  bool should_stage_host_to_device_transfers_;

  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;

  tensorflow::thread::ThreadPool thread_pool_;

  absl::Mutex transpose_mu_;
  TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_);
};

// Converts a 2D set of Device objects indexed by [replica][partition] into an
// xla::DeviceAssignment.
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
    absl::Span<const std::vector<PjRtDevice*>> devices);

class PjRtStreamExecutorBuffer : public PjRtBuffer {
 public:
  // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A
  // ScopedHold may not outlive its parent PjRtStreamExecutorBuffer.
  //
  // There are three types of hold, as follows:
  //
  // 1) Usage hold: a transient hold while an operation using the buffer is
  //    being enqueued onto a stream.
  // A client acquires a usage hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
  // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully
  // the hold should be released using a call to ConvertUsageHold. If the
  // ScopedHold is deleted without ConvertUsageHold being called, e.g., on
  // error, the hold is dropped. It is legal to drop a usage hold instead of
  // calling ConvertUsageHold, even if the buffer was successfully enqueued,
  // as long as the client ensures that all necessary synchronization has been
  // done.
  //
  // 2) External hold: a potentially long-lived hold while the buffer is being
  //    shared by an external framework, e.g., NumPy.
  // A client acquires an external hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
  // wrapper GetBufferWithExternalReference and releases it by deleting the
  // ScopedHold. The external framework should not modify the underlying buffer
  // unless it is confident via its own synchronization that modifications do
  // not race with reads from the PjRtStreamExecutorBuffer.
  //
  // 3) Donation hold: a transient hold while an execution that donates the
  //    buffer is being enqueued onto the compute stream.
  // A client acquires a donation hold by calling
  // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
  // completes successfully the hold should be released using a call to
  // ConfirmDonation after which the buffer is invalid. If the ScopedHold is
  // deleted without ConfirmDonation being called, e.g., on error, the hold is
  // dropped and the buffer remains valid. If the buffer is successfully
  // enqueued the client *must* call ConfirmDonation.
  //
  // Donation holds behave like exclusive write locks: when a donation hold
  // has been acquired, any attempt to acquire another hold of any type will
  // block until the donation hold is dropped or confirmed. Acquiring a
  // donation hold will fail with an error if there is any outstanding external
  // hold, and will block if there are any outstanding usage holds until those
  // holds are dropped or converted.
  //
  // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
  // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
  // block until all usage and donation holds are either deleted or
  // converted/confirmed.
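  //
  // Example: a hypothetical sketch of the usage-hold protocol described
  // above. `usage_stream` and `event` stand in for the objects available at
  // enqueue time.
  //
  //   PjRtStreamExecutorBuffer::ScopedHold hold =
  //       buffer->GetBufferWithUsageHold();
  //   TF_RETURN_IF_ERROR(hold.status());
  //   // ... enqueue an operation that reads the held TrackedDeviceBuffer ...
  //   hold.ConvertUsageHold(usage_stream, event, /*reference_held=*/true);
  //
  // If the enqueue fails, simply let `hold` go out of scope: the destructor
  // drops the hold and the buffer stays valid.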
  class ScopedHold {
   public:
    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
    // Use a State enum instead of encoding the state in an error Status to
    // avoid creating Status values in non-error cases. Creating a Status
    // entails several allocations and can add O(us) to every use of a hold.
    enum State {
      kUninitialized = 0,
      kValid,
      kMoved,
      kConverted,
      kReleased,
      kDonated,
      kError
    };

    ~ScopedHold();
    ScopedHold(ScopedHold&& other);
    ScopedHold(const ScopedHold&) = delete;
    ScopedHold& operator=(const ScopedHold&) = delete;

    Type type() const { return type_; }

    Status status() const {
      // Lazily create Status values only when they are requested.
      switch (state_) {
        case kUninitialized:
          return InvalidArgument("Buffer has not been initialized");
        case kValid:
          return OkStatus();
        case kMoved:
          return InvalidArgument("Buffer has been moved.");
        case kConverted:
          return InvalidArgument("Buffer has been converted");
        case kReleased:
          return InvalidArgument("Buffer has been released");
        case kDonated:
          return InvalidArgument("Buffer has been donated");
        case kError:
          return status_;
        default:
          CHECK(false) << "Unexpected state value " << state_;
      }
    }
    bool ok() const { return state_ == kValid; }

    // Access to the underlying device buffer storage. Requires this->ok().
    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
      CHECK_EQ(state_, kValid);
      CHECK_NE(buffer_, nullptr);
      return buffer_;
    }
    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
    const TrackedDeviceBuffer& operator*() const { return *buffer(); }

    // Converts the hold into a usage event. Only valid for holds of type
    // kUsage.
    //
    // usage_stream:   the stream that the buffer was used on.
    // event:          an event that has been recorded on usage_stream after
    //                 the buffer was used.
    // reference_held: true if and only if the caller has caused a
    //                 reference to this->buffer() to stay live until after
    //                 the host is sure that the usage (transfer or execution)
    //                 has completed.
    void ConvertUsageHold(se::Stream* usage_stream,
                          std::shared_ptr<BufferSequencingEvent> event,
                          bool reference_held);

    // Confirms that the buffer was successfully donated to an execution.
    // Only valid for holds of type kDonation. Causes the buffer to become
    // invalid.
    void ConfirmDonation();

    // Adds the held device buffers in order to 'iterator'. Used to add the
    // buffers to an ExecutionInput. We require but do not verify that
    // 'iterator' when passed in is pointing to a sub-tuple of the
    // ExecutionInput whose on_device_shape matches that of the
    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
    // out of bounds. Donates the device buffers if the hold type is kDonation,
    // otherwise retains ownership of the device buffers.
    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
                    ExecutionInput* execution_input,
                    se::DeviceMemoryAllocator* allocator) const;

   private:
    friend class PjRtStreamExecutorBuffer;
    friend class PjRtStreamExecutorClient;

    // Helper struct that makes it possible to move a ScopedHold through a
    // closure.
    using ForClosure =
        std::tuple<PjRtStreamExecutorBuffer*, Type, State, Status,
                   std::shared_ptr<TrackedDeviceBuffer>>;

    ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
        : parent_(parent), type_(type), state_(kUninitialized) {}
    explicit ScopedHold(const ForClosure& closure_helper)
        : parent_(std::get<0>(closure_helper)),
          type_(std::get<1>(closure_helper)),
          state_(std::get<2>(closure_helper)),
          status_(std::get<3>(closure_helper)),
          buffer_(std::get<4>(closure_helper)) {
      // Check the buffer is not in an error state.
      CHECK(status_.ok() && buffer_ != nullptr);
    }

    // Sets buffer state.
    void SetState(State state) { state_ = state; }

    // Sets buffer_ and status_. Called by parent_ to initialize the hold.
    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
    // Releases the contents of *this, so *this can subsequently be
    // deleted without releasing the parent's hold. Should be passed to the
    // appropriate constructor of another ScopedHold, e.g., when a hold must be
    // passed through a closure that is incompatible with std::move.
    ForClosure ToClosure();

    PjRtStreamExecutorBuffer* const parent_;
    const Type type_;

    // There is an invariant that if ok() then
    // buffer_ != nullptr.
    State state_;
    Status status_;
    std::shared_ptr<TrackedDeviceBuffer> buffer_;
  };

  PjRtStreamExecutorBuffer(Shape on_device_shape,
                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                           PjRtClient* client, PjRtDevice* device);
  ~PjRtStreamExecutorBuffer() override;

  PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

  const Shape& on_device_shape() const override { return on_device_shape_; }
  StatusOr<Shape> logical_on_device_shape() override;
  PjRtStreamExecutorDevice* device() const override { return device_; }
  PjRtPlatformId platform_id() const { return client_->platform_id(); }
  absl::string_view platform_name() const { return client_->platform_name(); }
  PjRtStreamExecutorClient* client() const override { return client_; }
  bool IsEmptyTuple() const {
    return on_device_shape_.IsTuple() &&
           on_device_shape_.tuple_shapes_size() == 0;
  }

  StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
      override;

  StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
      bool wait_for_operations_to_complete) override;

  using PjRtBuffer::ToLiteralSync;
  PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override;

  StatusOr<size_t> GetOnDeviceSizeInBytes() const override;

  PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset,
                                   int64_t transfer_size) override;

  // Drops the buffer's reference to its associated device memory, leaving the
  // buffer in an invalid state. The memory will be freed lazily when all async
  // operations using the buffer have completed, according to the allocation
  // semantics of the underlying platform. Delete may briefly block if another
  // thread is in the process of enqueuing an operation on this buffer, but it
  // will never block for a stream operation to complete. If an external
  // framework holds a reference to the TrackedDeviceBuffer via
  // GetBufferWithExternalReference, the memory will not be freed until the
  // external framework drops the reference.
  void Delete() override;

  bool IsDeleted() override;

  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
  // PjRtBuffer retains ownership of the device buffers.
  StatusOr<ShapedBuffer> AsShapedBuffer() const;

  // Returns a hold on the TrackedDeviceBuffer holding the device
  // buffers. See comment on ScopedHold.
  ScopedHold GetBufferWithHold(ScopedHold::Type type);
  ScopedHold GetBufferWithUsageHold() {
    return GetBufferWithHold(ScopedHold::kUsage);
  }
  ScopedHold GetBufferWithExternalReference() {
    return GetBufferWithHold(ScopedHold::kExternalReference);
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
      PjRtDevice* dst_device) override;

  void CopyToRemoteDevice(absl::string_view serialized_descriptor,
                          RemoteSendCallback on_done) override;

  void CopyToRemoteDeviceScattered(
      absl::Span<const std::pair<std::string, RemoteSendCallback>>
          serialized_descriptors_and_callbacks,
      const ScatterDetails& scatter_details) override;

  PjRtFuture<Status> GetReadyFuture() override;

  bool IsOnCpu() const override;

  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but returns the
  // TrackedDeviceBuffer rather than freeing the device memory, so that another
  // framework can take ownership of it. The buffer returned from Release may
  // be safely dropped at any time even if it still has pending async
  // operations. The client should call GetReadyFuture().Await() before calling
  // Release with wait_for_operations_to_complete=false, to ensure that the
  // host has synchronized past any outstanding write operations to the buffer.
  // If wait_for_operations_to_complete=true the host will block until any
  // potentially outstanding asynchronous operations have completed before
  // returning, in which case it is safe to read or mutate the returned buffer.
  // If the buffer was shared via an external reference it is the client's
  // responsibility that accesses via that reference do not interfere with
  // accesses via the buffer returned from Release.
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
      bool wait_for_operations_to_complete);
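
  // Example: a hypothetical sketch of handing the device memory to another
  // framework without blocking inside Release itself.
  //
  //   TF_RETURN_IF_ERROR(buffer->GetReadyFuture().Await());
  //   TF_ASSIGN_OR_RETURN(
  //       std::shared_ptr<TrackedDeviceBuffer> tracked_buffer,
  //       buffer->Release(/*wait_for_operations_to_complete=*/false));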
 private:
  friend class PjRtClient;

  // Blocks in mu_.Await until there are no more usage holds.
  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Blocks in mu_.Await until there is no donation hold.
  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
  // device_buffer_ is null, or if a donation hold was requested when there is
  // an outstanding external hold.
  // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds()
  // must be called first.)
  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
  // Initializes hold with an error if device_buffer_ is null, or if a donation
  // hold was requested when there is an outstanding external hold.
  // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHolds()
  // must be called first.)
  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
  // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
  // device_buffer_ was successfully enqueued on a stream.
  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
                        std::shared_ptr<BufferSequencingEvent> event,
                        bool reference_held);

  // Drops a donation hold and makes *this invalid for further use. Does a
  // sanity check that buffer==device_buffer_. Called after device_buffer_ was
  // successfully donated to an execution.
  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);

  // Drops a hold without taking any other action. Does a sanity check that
  // buffer==device_buffer_ or device_buffer_==nullptr.
  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);

  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
                     std::shared_ptr<BufferSequencingEvent>>>
  CopyToDeviceHelper(PjRtDevice* dst_device,
                     LocalDeviceState* dst_local_device,
                     LocalDeviceState* transfer_local_device,
                     se::Stream* transfer_stream,
                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);

  PjRtStreamExecutorClient* const client_;
  const Shape on_device_shape_;
  PjRtStreamExecutorDevice* const device_;

  mutable absl::Mutex mu_;
  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ ABSL_GUARDED_BY(mu_);
  // Count of holds on the buffer.
  std::array<int, ScopedHold::Type::kMaxValue> holds_ ABSL_GUARDED_BY(mu_);
  PjRtFuture<Status>::Promise definition_promise_ ABSL_GUARDED_BY(mu_);
};

// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
class PjRtStreamExecutorExecutable : public PjRtLoadedExecutable {
 public:
  PjRtStreamExecutorExecutable(
      std::vector<std::unique_ptr<LocalExecutable>> executables,
      bool parameter_is_tupled_arguments,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<LogicalDeviceIds> addressable_device_logical_ids,
      std::vector<PjRtDevice*> addressable_devices,
      PjRtStreamExecutorClient* client);

  ~PjRtStreamExecutorExecutable() override = default;

  PjRtStreamExecutorClient* client() const override { return client_; }

  absl::string_view name() const override;

  int num_replicas() const override {
    return executables_[0]->build_options().num_replicas();
  }

  int num_partitions() const override {
    return executables_[0]->build_options().num_partitions();
  }

  int64_t SizeOfGeneratedCodeInBytes() const override {
    int64_t size = 0;
    for (auto& executable : executables_) {
      size += executable->executable()->SizeOfGeneratedCodeInBytes();
    }
    return size;
  }

  const DeviceAssignment& device_assignment() const override {
    return *device_assignment_;
  }

  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const override {
    return addressable_device_logical_ids_;
  }

  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  // Return an HloModule per partition.
  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override;

  using PjRtLoadedExecutable::Execute;
  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options,
      std::optional<std::vector<PjRtFuture<Status>>>& returned_futures)
      override;

  using PjRtLoadedExecutable::ExecuteSharded;
  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options,
      std::optional<PjRtFuture<Status>>& returned_future,
      bool fill_future) override;

  using PjRtLoadedExecutable::ExecutePortable;
  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options,
      std::optional<PjRtFuture<Status>>& returned_future,
      bool fill_future) override;
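
  // Example: a hypothetical sketch of a single-replica, single-partition
  // launch; `arg` is an illustrative PjRtBuffer* argument.
  //
  //   std::vector<std::vector<PjRtBuffer*>> args = {{arg}};
  //   std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
  //   TF_ASSIGN_OR_RETURN(auto outputs,
  //                       executable->Execute(args, ExecuteOptions(),
  //                                           returned_futures));
  //   // outputs[replica][i] is the i-th output buffer of that replica.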
  void Delete() override { executables_.clear(); }

  bool IsDeleted() override { return executables_.empty(); }

  bool IsReturnedFutureSupported() const override { return true; }

  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
    return executables_;
  }

 protected:
  bool parameter_is_tupled_arguments() const {
    return parameter_is_tupled_arguments_;
  }

 private:
  friend class PjRtStreamExecutorClient;
  friend class PjRtTpuClient;
  friend class InternalPjRtTpuClient;

  // Initializes information about which arguments to which executables must be
  // donated due to aliases that were specified by the computation.
  Status SetUpDonation(bool tuple_inputs);

  // Returns a sorted list of the parameters that must be donated. Derived
  // classes may use custom logic.
  virtual absl::Span<int const> ParametersThatMustBeDonated(
      int executable_idx) const;

  virtual StatusOr<std::vector<ExecutionInput>>
  MakeExecutionInputsAndWaitForEvents(
      int device_ordinal, const ExecuteOptions& options,
      absl::Span<const Shape> executable_parameter_shapes,
      absl::Span<PjRtBuffer* const> argument_handles,
      absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
      absl::flat_hash_set<BufferSequencingEvent*>& events) const;

  StatusOr<ScopedShapedBuffer> EnqueueExecution(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, int executable_idx, const RunId& run_id,
      const ExecuteOptions& options, PjRtDevice* device,
      std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
      std::shared_ptr<DeviceAssignment> device_assignment,
      std::vector<std::function<void()>>& compute_callbacks) const;

  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
      int device_ordinal, const ExecuteOptions& options,
      ScopedShapedBuffer result_buffer,
      std::shared_ptr<BufferSequencingEvent> definition_event,
      PjRtDevice* device,
      std::vector<std::function<void()>>& compute_callbacks,
      std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
      const;

  StatusOr<Result> ExecuteHelper(
      absl::Span<PjRtBuffer* const> argument_handles, int replica,
      int partition, const RunId& run_id, const ExecuteOptions& options,
      bool fill_future, PjRtDevice* device = nullptr) const;

  // Create shared pointers so we can free them after the execution: with
  // asynchronous execution, the process being executed can outlive the
  // executable itself.
  PjRtStreamExecutorClient* const client_;
  // One executable per partition.
  std::vector<std::shared_ptr<LocalExecutable>> executables_;
  // On device shapes of the executable parameters.
  std::vector<std::vector<Shape>> on_device_executable_parameter_shapes_;
  // Per-executable sorted vector of parameters that have any aliased buffers
  // and thus must be donated when executing the computation.
  std::vector<std::vector<int>> parameters_that_must_be_donated_;
  std::shared_ptr<DeviceAssignment> device_assignment_;

  // True if the executables were compiled expecting arguments in a single
  // tuple.
  const bool parameter_is_tupled_arguments_;

  // The replica and partition indices of device_assignment_ to be run by this
  // client. On single-host platforms without partitioning, this is all
  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may
  // not be the case on multi-host platforms. If there are 4 replicas and 2
  // partitions on a single-host platform, the size of
  // addressable_device_logical_ids_ is 4*2 = 8.
  std::vector<LogicalDeviceIds> addressable_device_logical_ids_;

  // addressable_devices_[i] is the Device to which
  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
  // unique_ptrs to play well with the Python bindings (see xla.cc).
  std::vector<PjRtDevice*> addressable_devices_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_