xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <utility>
22 
23 #include "absl/container/inlined_vector.h"
24 #include "absl/synchronization/mutex.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/platform/mem.h"
30 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
31 
32 namespace xla {
33 
34 class MaybeOwningCpuMemory {
35  public:
36   MaybeOwningCpuMemory() = default;
37 
38   // Non-owning.
MaybeOwningCpuMemory(void * buf,size_t size)39   explicit MaybeOwningCpuMemory(void* buf, size_t size)
40       : buf_(buf), size_(size) {}
41 
42   // Owning.
43   using OwnedDataPtr =
44       std::unique_ptr<uint8_t[], decltype(tensorflow::port::AlignedFree)*>;
MaybeOwningCpuMemory(OwnedDataPtr data,size_t size)45   explicit MaybeOwningCpuMemory(OwnedDataPtr data, size_t size)
46       : buf_(data.get()), data_(std::move(data)), size_(size) {}
47 
48   // Move-only.
49   MaybeOwningCpuMemory(MaybeOwningCpuMemory&&) = default;
50   MaybeOwningCpuMemory& operator=(MaybeOwningCpuMemory&&) = default;
51   MaybeOwningCpuMemory(const MaybeOwningCpuMemory&) = delete;
52   MaybeOwningCpuMemory& operator=(const MaybeOwningCpuMemory&) = delete;
53 
54   // Owning.
AllocateShared(size_t size)55   static StatusOr<std::shared_ptr<MaybeOwningCpuMemory>> AllocateShared(
56       size_t size) {
57     uint8_t* data = static_cast<uint8_t*>(tensorflow::port::AlignedMalloc(
58         size, cpu_function_runtime::MinAlign()));
59     if (!data) {
60       return ResourceExhausted("Out of memory allocating %d bytes.", size);
61     }
62     return std::make_shared<MaybeOwningCpuMemory>(
63         OwnedDataPtr{data, tensorflow::port::AlignedFree}, size);
64   }
65 
data()66   void* data() const { return buf_; }
size()67   size_t size() const { return size_; }
owns_data()68   bool owns_data() const { return data_ != nullptr; }
69 
70  private:
71   void* buf_ = nullptr;                  // Non-owning data pointer.
72   OwnedDataPtr data_ = {nullptr, free};  // Owning data pointer;
73   size_t size_ = 0;                      // Size in number of bytes.
74 };
75 
// Payload type for tfrt::AsyncValueRef<CpuEvent>, which is used to signal
// that a CPU operation (for example a data transfer or running a program) has
// completed. The CpuEvent value itself is empty and carries no data.
struct CpuEvent {
  // Explicitly defaulted, so it stays trivial and non-throwing while making
  // default construction of events available.
  constexpr CpuEvent() noexcept = default;
};
81 
82 // Class that represents CPU buffers. It optionally owns the buffers. It also
83 // tracks the definition and usage of the memory to allow for synchronized usage
84 // and deletion of CPU memory. This class is thread-compatible.
85 class TrackedTfrtCpuDeviceBuffer {
86  public:
87   // For non-tuple, takes a single buffer.
88   // For tuple, takes the leaf buffers. Tuple index table created internally.
89   // Nested tuple is not supported.
90   TrackedTfrtCpuDeviceBuffer(
91       bool is_tuple,
92       absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers,
93       absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events,
94       std::function<void()> on_delete_callback = nullptr);
95 
96   TrackedTfrtCpuDeviceBuffer(
97       bool is_tuple,
98       absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers,
99       tfrt::AsyncValueRef<CpuEvent> definition_event,
100       std::function<void()> on_delete_callback = nullptr);
101 
102   // Move-only.
103   TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) = default;
104   TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) = default;
105   TrackedTfrtCpuDeviceBuffer(const TrackedTfrtCpuDeviceBuffer&) = delete;
106   TrackedTfrtCpuDeviceBuffer& operator=(const TrackedTfrtCpuDeviceBuffer&) =
107       delete;
108 
109   ~TrackedTfrtCpuDeviceBuffer();
110 
Buffers()111   absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> Buffers() {
112     return buffers_;
113   }
114 
115   std::shared_ptr<MaybeOwningCpuMemory> Buffer(const ShapeIndex& shape_index);
116 
definition_event()117   const tfrt::AsyncValueRef<CpuEvent>& definition_event() const {
118     return definition_event_;
119   }
120 
UsageEvents()121   absl::Span<const tfrt::AsyncValueRef<CpuEvent>> UsageEvents() const {
122     return usage_events_;
123   }
124 
125   void AddUsageEvents(absl::Span<tfrt::AsyncValueRef<CpuEvent>> events);
126 
127   // Return the usage events for the buffers. After
128   // LockUseAndTransferUsageEvents is called, it is illegal to AddUsageEvent.
129   absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4>
130   LockUseAndTransferUsageEvents();
131 
132   // Relinquishes ownership of the buffer's device memory, e.g., after the
133   // buffer is passed to a computation that aliases its inputs to outputs.
134   void ReleaseDeviceMemory();
135 
136  private:
137   bool is_tuple_;
138   // If tuple, tuple index table is created and stored.
139   std::shared_ptr<MaybeOwningCpuMemory> tuple_index_table_;
140   // If non-tuple, `buffers_` contains 1 buffer; otherwise all leaf buffers.
141   absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers_;
142   // The definition event are associated with CPU operations that write to the
143   // buffers.
144   tfrt::AsyncValueRef<CpuEvent> definition_event_;
145 
146   // Usage events are associated with CPU operations that read from the buffers.
147   absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> usage_events_;
148   // A callback to call when the TrackedTfrtCpuDeviceBuffer is about to be
149   // destroyed.
150   std::function<void()> on_delete_callback_;
151 };
152 }  // namespace xla
153 
154 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_
155