/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"

#include <algorithm>
#include <atomic>
#include <functional>
#include <iterator>
#include <memory>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

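// Sets the sequencing event and the stream on which it was recorded. May be
// called at most once; threads blocked in WaitForEventOnStream, DefinedOn, or
// IsComplete are released once the event has been set.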
void BufferSequencingEvent::SetSequencingEvent(EventPool::Handle event,
                                               se::Stream* stream) {
  absl::MutexLock lock(&mu_);
  CHECK(!event_.event());
  event_ = std::move(event);
  CHECK(streams_defined_on_.empty());
  streams_defined_on_.push_back(stream);
  sequence_number_.store(event_.sequence_number(), std::memory_order_seq_cst);
}

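// Returns true once SetSequencingEvent has been called; used as the condition
// for mu_.Await in the methods below.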
bool BufferSequencingEvent::EventHasBeenRecorded() const {
  return event_.event() != nullptr;
}

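// Returns the sequence number assigned when the event was recorded. A stored
// value of 0 means the event has not been set yet, so the CHECK fails.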
uint64_t BufferSequencingEvent::sequence_number() const {
  uint64_t seq = sequence_number_.load(std::memory_order_seq_cst);
  CHECK_NE(seq, 0);
  return seq;
}

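// Makes `stream` wait for the sequencing event, blocking first until the
// event has been recorded. Streams that already waited, or on which the
// buffer was defined, are skipped.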
void BufferSequencingEvent::WaitForEventOnStream(se::Stream* stream) {
  absl::MutexLock lock(&mu_);

  // We cannot wait for an event until ThenRecordEvent has been called; on GPU
  // newly created events are deemed to have already happened in the past.
  mu_.Await(
      absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));

  // The set of defined streams is expected to be very small (usually 1-2), so
  // a simple linear scan should be fast enough.
  if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
                stream) != streams_defined_on_.end()) {
    // stream is in streams_defined_on_; it doesn't need to be waited on.
    return;
  }

  stream->ThenWaitFor(event_.event());
  streams_defined_on_.push_back(stream);
}

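// Returns true if the buffer's contents are known to be defined on `stream`,
// i.e., `stream` recorded the event or has already waited for it.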
bool BufferSequencingEvent::DefinedOn(se::Stream* stream) {
  absl::MutexLock lock(&mu_);

  // We cannot wait for an event until ThenRecordEvent has been called; on GPU
  // newly created events are deemed to have already happened in the past.
  mu_.Await(
      absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));

  // The set of defined streams is expected to be very small (usually 1-2), so
  // a simple linear scan should be fast enough.
  return std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
                   stream) != streams_defined_on_.end();
}

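// Returns true if the sequencing event has completed on the device, blocking
// first until the event has been recorded.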
bool BufferSequencingEvent::IsComplete() {
  absl::MutexLock lock(&mu_);

  // We cannot poll for the status of an event until ThenRecordEvent has been
  // called; on GPU newly created events are deemed to have already happened
  // in the past.
  mu_.Await(
      absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));

  return event_.event()->PollForStatus() == se::Event::Status::kComplete;
}

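// Converts a ScopedShapedBuffer into a TrackedDeviceBuffer, transferring
// ownership of the device memory: each buffer is moved out of the
// ScopedShapedBuffer, which will no longer deallocate it.
//
// A minimal usage sketch (`scoped_buffer`, `definition_event`, and
// `on_device_shape` are hypothetical caller-provided values):
//
//   std::shared_ptr<TrackedDeviceBuffer> tracked =
//       TrackedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer,
//                                                   {definition_event});
//   ShapedBuffer shaped = tracked->AsShapedBuffer(on_device_shape);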
/* static */ std::shared_ptr<TrackedDeviceBuffer>
TrackedDeviceBuffer::FromScopedShapedBuffer(
    ScopedShapedBuffer* shaped_buffer,
    absl::Span<const std::shared_ptr<BufferSequencingEvent>>
        definition_events) {
  ShapeTree<se::DeviceMemoryBase>::iterator iterator =
      shaped_buffer->buffers().begin();
  std::vector<se::DeviceMemoryBase> buffers;
  // Most on-device shapes are not tuples, so the common case is a single
  // buffer.
  buffers.reserve(1);

  // Copy each buffer out of the ScopedShapedBuffer and clear the original
  // entry so that the ScopedShapedBuffer no longer owns (or deallocates) it.
  ShapeUtil::ForEachSubshape(
      shaped_buffer->on_device_shape(), [&](const Shape&, const ShapeIndex&) {
        CHECK(iterator != shaped_buffer->buffers().end());
        buffers.push_back(iterator->second);
        iterator->second = se::DeviceMemoryBase();
        ++iterator;
      });
  CHECK(iterator == shaped_buffer->buffers().end());
  return std::make_shared<TrackedDeviceBuffer>(
      shaped_buffer->memory_allocator(), shaped_buffer->device_ordinal(),
      absl::Span<se::DeviceMemoryBase>(buffers), definition_events,
      /*on_delete_callback=*/nullptr);
}

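// Returns a ShapedBuffer view of the tracked buffers with the given on-device
// shape. The returned ShapedBuffer does not own the device memory.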
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
    const Shape& on_device_shape) const {
  ShapedBuffer shaped_buffer(on_device_shape, device_ordinal_);
  ShapeTree<se::DeviceMemoryBase>::iterator iterator =
      shaped_buffer.buffers().begin();
  for (const se::DeviceMemoryBase& buf : device_memory_) {
    CHECK(iterator != shaped_buffer.buffers().end());
    iterator->second = buf;
    ++iterator;
  }
  CHECK(iterator == shaped_buffer.buffers().end());
  return shaped_buffer;
}

// See comment on ExecutionInput in xla/service/executable.h to understand
// the meaning of owned/unowned in that class.

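// Adds the tracked buffers to `iterator` as immutable, unowned memory: case
// (1) in the ExecutionInput comment.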
void TrackedDeviceBuffer::AddToInputAsImmutable(
    ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const {
  for (const se::DeviceMemoryBase& buf : device_memory_) {
    CHECK(*iterator != end);
    // Set buffers to be case (1) in the comment on ExecutionInput.
    (*iterator)->second = MaybeOwningDeviceMemory(buf);
    ++(*iterator);
  }
}

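// Adds the tracked buffers to `iterator` as owned memory that may be donated
// to the executable (case (2) in the ExecutionInput comment), and registers
// each index with `execution_input` via SetUnownedIndex.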
void TrackedDeviceBuffer::AddToInputAsDonated(
    ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
    ExecutionInput* execution_input,
    se::DeviceMemoryAllocator* allocator) const {
  for (const se::DeviceMemoryBase& buf : device_memory_) {
    CHECK(*iterator != end);
    // Set buffers to be case (2) in the comment on ExecutionInput.
    (*iterator)->second = MaybeOwningDeviceMemory(
        se::OwningDeviceMemory(buf, device_ordinal_, allocator));
    execution_input->SetUnownedIndex((*iterator)->first);
    ++(*iterator);
  }
}

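// The TrackedDeviceBuffer takes ownership of `device_memory`; unless
// `allocator` is null, the destructor deallocates each buffer through it.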
TrackedDeviceBuffer::TrackedDeviceBuffer(
    se::DeviceMemoryAllocator* allocator, int device_ordinal,
    absl::Span<se::DeviceMemoryBase const> device_memory,
    absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events,
    std::function<void()> on_delete_callback)
    : allocator_(allocator),
      device_ordinal_(device_ordinal),
      device_memory_(device_memory.begin(), device_memory.end()),
      definition_events_(std::make_move_iterator(definition_events.begin()),
                         std::make_move_iterator(definition_events.end())),
      in_use_(true),
      on_delete_callback_(std::move(on_delete_callback)) {}

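// Frees any owned device memory and then runs on_delete_callback_, if set.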
TrackedDeviceBuffer::~TrackedDeviceBuffer() {
  if (allocator_) {
    for (const se::DeviceMemoryBase& buffer : device_memory_) {
      Status status = allocator_->Deallocate(device_ordinal_, buffer);
      if (!status.ok()) {
        LOG(ERROR) << "Buffer deallocation failed: " << status;
      }
    }
  }
  if (on_delete_callback_) {
    on_delete_callback_();
  }
}

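// Records a usage of the buffer on `usage_stream`. At most one event is kept
// per stream; a later event (by sequence number) replaces an earlier one.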
void TrackedDeviceBuffer::AddUsageEvent(
    se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
    bool reference_held) {
  CHECK(in_use_);

  // If an event is already registered for this stream, keep only the more
  // recently recorded of the two.
  for (auto& existing : usage_events_) {
    if (existing.stream == usage_stream) {
      if (*existing.event < *event) {
        existing.event = event;
        existing.reference_held = reference_held;
      }
      return;
    }
  }
  usage_events_.push_back({usage_stream, event, reference_held});
}

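// Marks the buffer as no longer in use and transfers out the accumulated
// usage events; AddUsageEvent must not be called afterwards. A sketch of the
// intended teardown flow (`tracked` is a hypothetical caller-held pointer):
//
//   TrackedDeviceBuffer::StreamAndEventContainer events =
//       tracked->LockUseAndTransferUsageEvents();
//   // ... defer freeing the device memory until every event has completed.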
TrackedDeviceBuffer::StreamAndEventContainer
TrackedDeviceBuffer::LockUseAndTransferUsageEvents() {
  CHECK(in_use_);
  in_use_ = false;
  return std::move(usage_events_);
}

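// Collects `buffer`'s usage events (if `get_usage_events` is true) or its
// definition events (otherwise) into `events`, deduplicating via the set.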
void GetDeviceBufferEvents(
    const TrackedDeviceBuffer& buffer, bool get_usage_events,
    absl::flat_hash_set<BufferSequencingEvent*>* events) {
  if (get_usage_events) {
    for (const auto& e : buffer.usage_events()) {
      events->insert(e.event.get());
    }
  } else {
    for (const auto& e : buffer.definition_events()) {
      events->insert(e.get());
    }
  }
}

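// Makes `stream` wait for all of `buffer`'s definition events. Collecting the
// events into a set first avoids waiting on the same event twice.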
void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
                                           se::Stream* stream) {
  absl::flat_hash_set<BufferSequencingEvent*> events;
  GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events);
  for (BufferSequencingEvent* event : events) {
    event->WaitForEventOnStream(stream);
  }
}

}  // namespace xla