xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
18 
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream.h"
30 
31 namespace xla {
32 
33 // A BufferSequencingEvent keeps track of dependencies of a buffer on each
34 // stream it has been used on.
35 //
36 // Each logical buffer in an XLA computation may be defined (i.e., written to)
37 // at most once. We call the operation that writes the buffer's value on some
38 // stream (e.g., a transfer or compute kernel) the buffer's definition event.
39 //
40 // After the operation that populates the value of a buffer has been enqueued on
41 // 'stream', RecordOnStream(stream) should also be called to trigger the
42 // definition event after the operation has completed.
43 //
44 // After the buffer is read on 'stream' another event should be added so that
45 // it is possible to sequence buffer donation after all reads have completed.
46 //
47 // Since different streams are not necessarily synchronized with one another,
48 // if we wish to consume the value of the buffer on a different stream, we
49 // should first call WaitForEventOnStream(stream), which add a cross-stream
50 // from 'stream' to the buffer's definition event, causing 'stream' to pause
51 // until the definition event has been triggered, if needed. Operations on
52 // 'stream' may then assume that the buffer is valid and its contents correspond
53 // to the desired buffer.
54 //
55 // The dependency logic caches the set of streams at the tail of which the
56 // definition event is known to have occurred; waiting for the same event on the
57 // same stream causes no additional waiting.
58 class BufferSequencingEvent {
59  public:
60   BufferSequencingEvent() = default;
61 
62   // Sets the sequencing event to 'event', which is recorded on 'stream'. Must
63   // be called at most once. Unblocks any other host threads that are blocked in
64   // WaitForEventOnStream.
65   void SetSequencingEvent(EventPool::Handle event, se::Stream* stream);
66 
67   // Adds synchronization events to 'stream' that wait for this event to be
68   // defined on 'stream'. Does nothing if the event is already known to have
69   // occurred by the tail of 'stream'. If RecordOnStream has not yet been
70   // called, blocks the calling thread until the event has been recorded.
71   void WaitForEventOnStream(se::Stream* stream);
72 
73   // Returns true if the event is known to have occurred by the tail of
74   // 'stream'. If RecordOnStream has not yet been called, blocks the calling
75   // thread until the event has been recorded.
76   bool DefinedOn(se::Stream* stream);
77 
78   // Returns true if the event is known by the host to have already occurred. If
79   // RecordOnStream has not yet been called, blocks the calling thread until the
80   // event has been recorded.
81   bool IsComplete();
82 
83   // Compares the sequence numbers of two recorded events. It is illegal to call
84   // the comparison operators unless both events have been recorded.
85   inline bool operator<(const BufferSequencingEvent& rhs) const {
86     return sequence_number() < rhs.sequence_number();
87   }
88   inline bool operator>(const BufferSequencingEvent& rhs) const {
89     return rhs < *this;
90   }
91   inline bool operator<=(const BufferSequencingEvent& rhs) const {
92     return !(*this > rhs);
93   }
94   inline bool operator>=(const BufferSequencingEvent& rhs) const {
95     return !(*this < rhs);
96   }
97 
98  private:
99   bool EventHasBeenRecorded() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
100   uint64_t sequence_number() const;
101 
102   // An event that is triggered when the content of one or more buffers has been
103   // read or written. If this event is used as a definition event and is
104   // nullptr, it is assumed that the buffer's content is always defined for
105   // example because it uses storage borrowed from elsewhere.
106   EventPool::Handle event_;
107 
108   // Cache of event_->sequence_number that avoids synchronization overhead.
109   // TODO(phawkins): In fact, event_->sequence_number is unused beyond the
110   // initial population of sequence_number_, and we could remove it if we
111   // refactored the EventPool API.
112   std::atomic<uint64_t> sequence_number_{0};
113 
114   mutable absl::Mutex mu_;
115   // A list of all streams for which the buffer's content is known to be defined
116   // at the tail of the queue, i.e., for any newly enqueued command.
117   absl::InlinedVector<se::Stream*, 2> streams_defined_on_ ABSL_GUARDED_BY(mu_);
118 };
119 
120 // Class that represents a tuple of device buffers. Like a ScopedShapedBuffer it
121 // owns all of the device memory in the tuple. It also tracks the definition and
122 // usage of the memory on streams, to allow for synchronized usage and deletion
123 // of memory under all of the allocation model semantics.
124 class TrackedDeviceBuffer {
125  public:
126   // Helper object to keep track of usage of the buffer on streams.
127   struct StreamAndEvent {
128     // A stream the buffer has been used on.
129     se::Stream* stream;
130     // An event that is later than the most recent usage of the buffer on
131     // stream.
132     std::shared_ptr<BufferSequencingEvent> event;
133     // True if and only if a reference to the buffer is kept live until after
134     // the host knows that event is complete.
135     bool reference_held;
136   };
137 
138   // Converts a ScopedShapedBuffer into a TrackedDeviceBuffer. Takes ownership
139   // of the buffers of the shaped_buffer.
140   static std::shared_ptr<TrackedDeviceBuffer> FromScopedShapedBuffer(
141       ScopedShapedBuffer* shaped_buffer,
142       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
143           definition_events);
144 
145   // Builds a ShapedBuffer view onto the buffers of 'tree'.
146   ShapedBuffer AsShapedBuffer(const Shape& on_device_shape) const;
147 
148   // Adds the owned device buffers in order to 'iterator'. Used to add the
149   // buffers to an ExecutionInput. We require but do not verify that 'iterator'
150   // when passed in is pointing to a sub-tuple of the ExecutionInput whose
151   // on_device_shape matches that of the TrackedDeviceBuffer. 'end' is used to
152   // check that 'iterator' doesn't run out of bounds.
153   void AddToInputAsImmutable(
154       ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
155       const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const;
156 
157   // Adds the owned device buffers in order to 'iterator', marking them as
158   // available to be donated. If donation succeeds, i.e., execution_input is
159   // subsequently successfully enqueued to a computation,
160   // this->ReleaseDeviceMemory() must be called to avoid freeing the device
161   // memory twice. We require but do not verify that 'iterator' when passed in
162   // is pointing to a sub-tuple of execution_input whose on_device_shape matches
163   // that of the TrackedDeviceBuffer. 'end' is used to check that 'iterator'
164   // doesn't run out of bounds.
165   void AddToInputAsDonated(
166       ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
167       const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
168       ExecutionInput* execution_input,
169       se::DeviceMemoryAllocator* allocator) const;
170 
allocator()171   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
device_ordinal()172   int device_ordinal() const { return device_ordinal_; }
device_memory()173   absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() {
174     return device_memory_;
175   }
device_memory()176   const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
177     return device_memory_;
178   }
definition_events()179   absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events()
180       const {
181     return definition_events_;
182   }
usage_events()183   absl::Span<const StreamAndEvent> usage_events() const {
184     return usage_events_;
185   }
186 
187   // Relinquishes ownership of the buffer's device memory, e.g., after the
188   // buffer is passed to a computation that aliases its inputs to outputs.
ReleaseDeviceMemory()189   void ReleaseDeviceMemory() { device_memory_.clear(); }
190 
191   // Indicates that the buffer has been used on a stream.
192   //
193   //   usage_stream:   a stream that the buffer was used on.
194   //   event:          an event that has been recorded on usage_stream after the
195   //                   buffer was used.
196   //   reference_held: true if and only if the caller has caused a memory
197   //                   reference to *this to stay live until after the host
198   //                   is sure that the usage (transfer or execution) has
199   //                   completed.
200   void AddUsageEvent(se::Stream* usage_stream,
201                      std::shared_ptr<BufferSequencingEvent> event,
202                      bool reference_held);
203 
204   using StreamAndEventContainer = absl::InlinedVector<StreamAndEvent, 3>;
205   // Returns the set of streams that the buffer was used on, and for each stream
206   // an event later than the last use of the buffer. After
207   // LockUseAndTransferUsageEvents is called it is illegal to use the buffer on
208   // any stream and, e.g. AddUsageHold will CHECK fail.
209   StreamAndEventContainer LockUseAndTransferUsageEvents();
210 
TrackedDeviceBuffer()211   TrackedDeviceBuffer() : in_use_(true) {}
212   TrackedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
213                       absl::Span<se::DeviceMemoryBase const> device_memory,
214                       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
215                           definition_events,
216                       std::function<void()> on_delete_callback);
217   ~TrackedDeviceBuffer();
218 
219  private:
220   // Are the buffers in device_memory_ owned? If so, which allocator and device
221   // ordinal? May be nullptr, indicating the buffers are not owned.
222   se::DeviceMemoryAllocator* allocator_;
223   int device_ordinal_;
224 
225   // Each host-side buffer may have several buffers on-device.
226   absl::InlinedVector<se::DeviceMemoryBase, 1> device_memory_;
227 
228   // Events that are triggered when the content of one or more buffers is ready
229   // during multistream execution. May be nullptr, which is used in the
230   // single-stream execution case where events are not necessary for buffer
231   // event sequencing. All events must be triggered before the buffers can be
232   // used.
233   absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
234       definition_events_;
235 
236   // in_use_ starts out true, and is set to false when the buffer is released
237   // from its owning PjRtBuffer. Once in_use_ is false, the buffer may no
238   // longer be used on any stream.
239   bool in_use_;
240   // Set of streams that the buffer has ever been used on, see comment on
241   // StreamAndEvent.
242   StreamAndEventContainer usage_events_;
243 
244   // A callback to call when the TrackedDeviceBuffer is about to be destroyed.
245   std::function<void()> on_delete_callback_;
246 };
247 
248 // Populates 'events' with the set of buffer events for buffer. If
249 // get_usage_events=true populates with the latest usage events, otherwise
250 // populates with the definition events.
251 void GetDeviceBufferEvents(const TrackedDeviceBuffer& buffer,
252                            bool get_usage_events,
253                            absl::flat_hash_set<BufferSequencingEvent*>* events);
254 
255 // Waits for all of the definition events in a buffer on 'stream'.
256 void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
257                                            se::Stream* stream);
258 
259 }  // namespace xla
260 
261 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
262