1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <utility> 22 23 #include "absl/container/inlined_vector.h" 24 #include "absl/synchronization/mutex.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/cpu_function_runtime.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/util.h" 29 #include "tensorflow/core/platform/mem.h" 30 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime 31 32 namespace xla { 33 34 class MaybeOwningCpuMemory { 35 public: 36 MaybeOwningCpuMemory() = default; 37 38 // Non-owning. MaybeOwningCpuMemory(void * buf,size_t size)39 explicit MaybeOwningCpuMemory(void* buf, size_t size) 40 : buf_(buf), size_(size) {} 41 42 // Owning. 43 using OwnedDataPtr = 44 std::unique_ptr<uint8_t[], decltype(tensorflow::port::AlignedFree)*>; MaybeOwningCpuMemory(OwnedDataPtr data,size_t size)45 explicit MaybeOwningCpuMemory(OwnedDataPtr data, size_t size) 46 : buf_(data.get()), data_(std::move(data)), size_(size) {} 47 48 // Move-only. 49 MaybeOwningCpuMemory(MaybeOwningCpuMemory&&) = default; 50 MaybeOwningCpuMemory& operator=(MaybeOwningCpuMemory&&) = default; 51 MaybeOwningCpuMemory(const MaybeOwningCpuMemory&) = delete; 52 MaybeOwningCpuMemory& operator=(const MaybeOwningCpuMemory&) = delete; 53 54 // Owning. AllocateShared(size_t size)55 static StatusOr<std::shared_ptr<MaybeOwningCpuMemory>> AllocateShared( 56 size_t size) { 57 uint8_t* data = static_cast<uint8_t*>(tensorflow::port::AlignedMalloc( 58 size, cpu_function_runtime::MinAlign())); 59 if (!data) { 60 return ResourceExhausted("Out of memory allocating %d bytes.", size); 61 } 62 return std::make_shared<MaybeOwningCpuMemory>( 63 OwnedDataPtr{data, tensorflow::port::AlignedFree}, size); 64 } 65 data()66 void* data() const { return buf_; } size()67 size_t size() const { return size_; } owns_data()68 bool owns_data() const { return data_ != nullptr; } 69 70 private: 71 void* buf_ = nullptr; // Non-owning data pointer. 72 OwnedDataPtr data_ = {nullptr, free}; // Owning data pointer; 73 size_t size_ = 0; // Size in number of bytes. 74 }; 75 76 // tfrt::AsyncValueRef<CpuEvent> is used to indicate the completion of a CPU 77 // operation, e.g., data transfer or running a program. 78 struct CpuEvent { 79 CpuEvent() = default; 80 }; 81 82 // Class that represents CPU buffers. It optionally owns the buffers. It also 83 // tracks the definition and usage of the memory to allow for synchronized usage 84 // and deletion of CPU memory. This class is thread-compatible. 85 class TrackedTfrtCpuDeviceBuffer { 86 public: 87 // For non-tuple, takes a single buffer. 88 // For tuple, takes the leaf buffers. Tuple index table created internally. 89 // Nested tuple is not supported. 90 TrackedTfrtCpuDeviceBuffer( 91 bool is_tuple, 92 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers, 93 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events, 94 std::function<void()> on_delete_callback = nullptr); 95 96 TrackedTfrtCpuDeviceBuffer( 97 bool is_tuple, 98 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers, 99 tfrt::AsyncValueRef<CpuEvent> definition_event, 100 std::function<void()> on_delete_callback = nullptr); 101 102 // Move-only. 103 TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) = default; 104 TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) = default; 105 TrackedTfrtCpuDeviceBuffer(const TrackedTfrtCpuDeviceBuffer&) = delete; 106 TrackedTfrtCpuDeviceBuffer& operator=(const TrackedTfrtCpuDeviceBuffer&) = 107 delete; 108 109 ~TrackedTfrtCpuDeviceBuffer(); 110 Buffers()111 absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> Buffers() { 112 return buffers_; 113 } 114 115 std::shared_ptr<MaybeOwningCpuMemory> Buffer(const ShapeIndex& shape_index); 116 definition_event()117 const tfrt::AsyncValueRef<CpuEvent>& definition_event() const { 118 return definition_event_; 119 } 120 UsageEvents()121 absl::Span<const tfrt::AsyncValueRef<CpuEvent>> UsageEvents() const { 122 return usage_events_; 123 } 124 125 void AddUsageEvents(absl::Span<tfrt::AsyncValueRef<CpuEvent>> events); 126 127 // Return the usage events for the buffers. After 128 // LockUseAndTransferUsageEvents is called, it is illegal to AddUsageEvent. 129 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> 130 LockUseAndTransferUsageEvents(); 131 132 // Relinquishes ownership of the buffer's device memory, e.g., after the 133 // buffer is passed to a computation that aliases its inputs to outputs. 134 void ReleaseDeviceMemory(); 135 136 private: 137 bool is_tuple_; 138 // If tuple, tuple index table is created and stored. 139 std::shared_ptr<MaybeOwningCpuMemory> tuple_index_table_; 140 // If non-tuple, `buffers_` contains 1 buffer; otherwise all leaf buffers. 141 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers_; 142 // The definition event are associated with CPU operations that write to the 143 // buffers. 144 tfrt::AsyncValueRef<CpuEvent> definition_event_; 145 146 // Usage events are associated with CPU operations that read from the buffers. 147 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> usage_events_; 148 // A callback to call when the TrackedTfrtCpuDeviceBuffer is about to be 149 // destroyed. 150 std::function<void()> on_delete_callback_; 151 }; 152 } // namespace xla 153 154 #endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ 155