xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"

#include <algorithm>
#include <atomic>
#include <functional>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"
30 
31 namespace xla {
32 
SetSequencingEvent(EventPool::Handle event,se::Stream * stream)33 void BufferSequencingEvent::SetSequencingEvent(EventPool::Handle event,
34                                                se::Stream* stream) {
35   absl::MutexLock lock(&mu_);
36   CHECK(!event_.event());
37   event_ = std::move(event);
38   CHECK(streams_defined_on_.empty());
39   streams_defined_on_.push_back(stream);
40   sequence_number_.store(event_.sequence_number(), std::memory_order_seq_cst);
41 }
42 
EventHasBeenRecorded() const43 bool BufferSequencingEvent::EventHasBeenRecorded() const {
44   return event_.event() != nullptr;
45 }
46 
sequence_number() const47 uint64_t BufferSequencingEvent::sequence_number() const {
48   uint64_t seq = sequence_number_.load(std::memory_order_seq_cst);
49   CHECK_NE(seq, 0);
50   return seq;
51 }
52 
WaitForEventOnStream(se::Stream * stream)53 void BufferSequencingEvent::WaitForEventOnStream(se::Stream* stream) {
54   absl::MutexLock lock(&mu_);
55 
56   // We cannot wait for an event until ThenRecordEvent has been called; on GPU
57   // newly created events are deemed to have already happened past.
58   mu_.Await(
59       absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));
60 
61   // The set of defined streams is expected to be very small indeed (usually
62   // 1-2), so a simple linear scan should be fast enough.
63   if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
64                 stream) != streams_defined_on_.end()) {
65     // stream is in streams_defined_on_; it doesn't need to be waited on.
66     return;
67   }
68 
69   stream->ThenWaitFor(event_.event());
70   streams_defined_on_.push_back(stream);
71 }
72 
DefinedOn(se::Stream * stream)73 bool BufferSequencingEvent::DefinedOn(se::Stream* stream) {
74   absl::MutexLock lock(&mu_);
75 
76   // We cannot wait for an event until ThenRecordEvent has been called; on GPU
77   // newly created events are deemed to have already happened past.
78   mu_.Await(
79       absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));
80 
81   // The set of defined streams is expected to be very small indeed (usually
82   // 1-2), so a simple linear scan should be fast enough.
83   return std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
84                    stream) != streams_defined_on_.end();
85 }
86 
IsComplete()87 bool BufferSequencingEvent::IsComplete() {
88   absl::MutexLock lock(&mu_);
89 
90   // We cannot wait for an event until ThenRecordEvent has been called; on
91   // GPU newly created events are deemed to have already happened past.
92   mu_.Await(
93       absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));
94 
95   return event_.event()->PollForStatus() == se::Event::Status::kComplete;
96 }
97 
98 /* static */ std::shared_ptr<TrackedDeviceBuffer>
FromScopedShapedBuffer(ScopedShapedBuffer * shaped_buffer,absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events)99 TrackedDeviceBuffer::FromScopedShapedBuffer(
100     ScopedShapedBuffer* shaped_buffer,
101     absl::Span<const std::shared_ptr<BufferSequencingEvent>>
102         definition_events) {
103   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
104       shaped_buffer->buffers().begin();
105   std::vector<se::DeviceMemoryBase> buffers;
106   buffers.reserve(1);
107 
108   ShapeUtil::ForEachSubshape(
109       shaped_buffer->on_device_shape(), [&](const Shape&, const ShapeIndex&) {
110         CHECK(iterator != shaped_buffer->buffers().end());
111         buffers.push_back(iterator->second);
112         iterator->second = se::DeviceMemoryBase();
113         ++iterator;
114       });
115   CHECK(iterator == shaped_buffer->buffers().end());
116   return std::make_shared<TrackedDeviceBuffer>(
117       shaped_buffer->memory_allocator(), shaped_buffer->device_ordinal(),
118       absl::Span<se::DeviceMemoryBase>(buffers), definition_events,
119       /*on_delete_callback=*/nullptr);
120 }
121 
AsShapedBuffer(const Shape & on_device_shape) const122 ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
123     const Shape& on_device_shape) const {
124   ShapedBuffer shaped_buffer(on_device_shape, device_ordinal_);
125   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
126       shaped_buffer.buffers().begin();
127   for (const se::DeviceMemoryBase& buf : device_memory_) {
128     CHECK(iterator != shaped_buffer.buffers().end());
129     iterator->second = buf;
130     ++iterator;
131   }
132   CHECK(iterator == shaped_buffer.buffers().end());
133   return shaped_buffer;
134 }
135 
136 // See comment on ExecutionInput in xla/service/executable.h to understand
137 // the meaning of owned/unowned in that class.
138 
AddToInputAsImmutable(ShapeTree<MaybeOwningDeviceMemory>::iterator * iterator,const ShapeTree<MaybeOwningDeviceMemory>::iterator & end) const139 void TrackedDeviceBuffer::AddToInputAsImmutable(
140     ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
141     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const {
142   for (const se::DeviceMemoryBase& buf : device_memory_) {
143     CHECK(*iterator != end);
144     // Set buffers to be case (1) in the comment on ExecutionInput.
145     (*iterator)->second = MaybeOwningDeviceMemory(buf);
146     ++(*iterator);
147   }
148 }
149 
AddToInputAsDonated(ShapeTree<MaybeOwningDeviceMemory>::iterator * iterator,const ShapeTree<MaybeOwningDeviceMemory>::iterator & end,ExecutionInput * execution_input,se::DeviceMemoryAllocator * allocator) const150 void TrackedDeviceBuffer::AddToInputAsDonated(
151     ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
152     const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
153     ExecutionInput* execution_input,
154     se::DeviceMemoryAllocator* allocator) const {
155   for (const se::DeviceMemoryBase& buf : device_memory_) {
156     CHECK(*iterator != end);
157     // Set buffers to be case (2) in the comment on ExecutionInput.
158     (*iterator)->second = MaybeOwningDeviceMemory(
159         se::OwningDeviceMemory(buf, device_ordinal_, allocator));
160     execution_input->SetUnownedIndex((*iterator)->first);
161     ++(*iterator);
162   }
163 }
164 
TrackedDeviceBuffer(se::DeviceMemoryAllocator * allocator,int device_ordinal,absl::Span<se::DeviceMemoryBase const> device_memory,absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events,std::function<void ()> on_delete_callback)165 TrackedDeviceBuffer::TrackedDeviceBuffer(
166     se::DeviceMemoryAllocator* allocator, int device_ordinal,
167     absl::Span<se::DeviceMemoryBase const> device_memory,
168     absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events,
169     std::function<void()> on_delete_callback)
170     : allocator_(allocator),
171       device_ordinal_(device_ordinal),
172       device_memory_(device_memory.begin(), device_memory.end()),
173       definition_events_(std::make_move_iterator(definition_events.begin()),
174                          std::make_move_iterator(definition_events.end())),
175       in_use_(true),
176       on_delete_callback_(std::move(on_delete_callback)) {}
177 
~TrackedDeviceBuffer()178 TrackedDeviceBuffer::~TrackedDeviceBuffer() {
179   if (allocator_) {
180     for (const se::DeviceMemoryBase& buffer : device_memory_) {
181       Status status = allocator_->Deallocate(device_ordinal_, buffer);
182       if (!status.ok()) {
183         LOG(ERROR) << "Buffer deallocation failed: " << status;
184       }
185     }
186   }
187   if (on_delete_callback_) {
188     on_delete_callback_();
189   }
190 }
191 
AddUsageEvent(se::Stream * usage_stream,std::shared_ptr<BufferSequencingEvent> event,bool reference_held)192 void TrackedDeviceBuffer::AddUsageEvent(
193     se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
194     bool reference_held) {
195   CHECK(in_use_);
196 
197   for (auto& existing : usage_events_) {
198     if (existing.stream == usage_stream) {
199       if (*existing.event < *event) {
200         existing.event = event;
201         existing.reference_held = reference_held;
202       }
203       return;
204     }
205   }
206   usage_events_.push_back({usage_stream, event, reference_held});
207 }
208 
209 TrackedDeviceBuffer::StreamAndEventContainer
LockUseAndTransferUsageEvents()210 TrackedDeviceBuffer::LockUseAndTransferUsageEvents() {
211   CHECK(in_use_);
212   in_use_ = false;
213   return std::move(usage_events_);
214 }
215 
GetDeviceBufferEvents(const TrackedDeviceBuffer & buffer,bool get_usage_events,absl::flat_hash_set<BufferSequencingEvent * > * events)216 void GetDeviceBufferEvents(
217     const TrackedDeviceBuffer& buffer, bool get_usage_events,
218     absl::flat_hash_set<BufferSequencingEvent*>* events) {
219   if (get_usage_events) {
220     for (const auto& e : buffer.usage_events()) {
221       events->insert(e.event.get());
222     }
223   } else {
224     for (const auto& e : buffer.definition_events()) {
225       events->insert(e.get());
226     }
227   }
228 }
229 
WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer & buffer,se::Stream * stream)230 void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
231                                            se::Stream* stream) {
232   absl::flat_hash_set<BufferSequencingEvent*> events;
233   GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events);
234   for (BufferSequencingEvent* event : events) {
235     event->WaitForEventOnStream(stream);
236   }
237 }
238 
239 }  // namespace xla
240