/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_memory_manager.h"

#include <algorithm>
#include <list>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace {

// We use kDeviceBits to store the device ordinal in the handle. We store the
// device in the upper part of the int64 handle to make sure the random bits
// are in the lower part, which is better when the handle is used as a key in
// unordered maps.
const int kDeviceBits = 12;
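// For example, with kDeviceBits = 12 a handle produced by MakeDeviceHandle()
// below is laid out as:
//   bits [63..52]  device ordinal
//   bits [51..0]   random UID (rnd_value & kUidMask)
// so GetDeviceFromHandle() recovers the device ordinal by shifting right by 52.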

int64_t MakeDeviceHandle(int64_t device_ordinal, int64_t rnd_value) {
  const int64_t kUidMask = (static_cast<int64_t>(1) << (64 - kDeviceBits)) - 1;
  return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
}

int GetDeviceFromHandle(int64_t handle) {
  return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
}

}  // namespace

class XRTMemoryManager::DeviceContext {
  struct Alloc {
    explicit Alloc(RefPtr<XRTTupleAllocation> tuple)
        : tuple(std::move(tuple)) {}

    RefPtr<XRTTupleAllocation> tuple;
  };

  using AllocList = std::list<Alloc>;
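
  // Note: allocs_ is kept in most-recently-used-first order. Register() and
  // Lookup() place entries at the front, while TryFreeMemory() walks the list
  // from the back so the least recently used allocations are swapped out
  // first.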

 public:
  int64_t Register(RefPtr<XRTTupleAllocation> tuple) {
    while (true) {
      int64_t handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid());
      mutex_lock lock(lock_);
      allocs_.emplace_front(tuple);
      if (alloc_map_.emplace(handle, allocs_.begin()).second) {
        return handle;
      }
      // The chances of hitting an existing handle are so remote that it is
      // more convenient to add to the list up front, and remove the entry in
      // the rare case of a collision.
      allocs_.erase(allocs_.begin());
    }
  }

  bool Release(int64_t handle) {
    mutex_lock lock(lock_);
    auto it = alloc_map_.find(handle);
    if (it == alloc_map_.end()) {
      return false;
    }
    allocs_.erase(it->second);
    alloc_map_.erase(it);
    return true;
  }

  RefPtr<XRTTupleAllocation> Lookup(int64_t handle) {
    mutex_lock lock(lock_);
    auto it = alloc_map_.find(handle);
    if (it == alloc_map_.end()) {
      return nullptr;
    }
    // LRU update: move the accessed entry to the front of the list.
    allocs_.splice(allocs_.begin(), allocs_, it->second);
    return it->second->tuple;
  }

  void Clear() {
    mutex_lock lock(lock_);
    alloc_map_.clear();
    allocs_.clear();
  }

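  // Compacts the device memory space by temporarily swapping every allocation
  // (including pinned ones) out to host memory and then swapping it back in.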
  Status CompactAllocations(XRTMemoryManager* memory_manager,
                            xla::Backend* backend,
                            se::DeviceMemoryAllocator* allocator) {
    profiler::TraceMe trace_me("XRTMemoryManager::CompactAllocations",
                               /*level=*/2);
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMemoryCompactCell());
    VLOG(4) << "CompactAllocations started";
    mutex_lock lock(lock_);
    Status status;
    std::vector<AllocList::iterator> swapped;
    // We swap out starting from the most recently used allocations. This is
    // desirable since the most recently used allocations will end up at the
    // bottom of the allocation space. Since these are more likely to be pinned
    // allocations, a further trim done by a following TryFreeMemory() call
    // will eventually drop the allocations located higher up, with a better
    // chance of reducing fragmentation.
    // Also, by swapping out the pinned allocations first, those will also be
    // the first to be restored, so if we ever hit an OOM on the way back in,
    // we are more likely to be swapping in unpinned ones.
    for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
      // We are compacting all the allocations, so we temporarily swap out even
      // pinned allocations.
      auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true);
      if (!swap_result_or.ok()) {
        status = swap_result_or.status();
        break;
      }
      if (swap_result_or.ValueOrDie()) {
        swapped.push_back(it);
      }
    }
    // At this point we have released all the device memory we could release.
    // Load back the tuple allocations we have swapped out above.
    for (auto& it : swapped) {
      auto swap_result_or =
          it->tuple->SwapIn(memory_manager, backend, allocator);
      if (!swap_result_or.ok()) {
        // If we failed to restore a pinned allocation, it is better to CHECK
        // here than to wonder later why XRTTupleAllocation calls fail with
        // errors about missing buffers.
        CHECK(!it->tuple->IsPinned());  // Crash OK
        if (status.ok()) {
          status = swap_result_or.status();
        }
      }
    }
    VLOG(4) << "CompactAllocations finished: " << status;
    return status;
  }

  // Tries to free 'size' bytes by swapping out some unpinned device memory.
  // Returns the amount of memory it was able to free.
  xla::StatusOr<size_t> TryFreeMemory(xla::Backend* backend, size_t size) {
    profiler::TraceMe trace_me("XRTMemoryManager::TryFreeMemory", /*level=*/2);
    auto timed = monitoring::MakeTimed(xrt_metrics::GetTryFreeMemoryCell());
    mutex_lock lock(lock_);
    size_t swapped_size = 0;
    for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
      TF_ASSIGN_OR_RETURN(bool swap_result,
                          it->tuple->SwapOut(backend, /*swap_pinned=*/false));
      if (swap_result) {
        swapped_size += it->tuple->GetDeviceMemorySize();
        if (swapped_size >= size) {
          break;
        }
      }
    }
    VLOG(3) << "Swapped out " << swapped_size << " bytes";
    return swapped_size;
  }

 private:
  static int64_t CreateUid() {
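    // Draw 63-bit random values (the sign bit is masked off) until we get one
    // that differs from the reserved invalid-key value.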
    int64_t uid;
    do {
      uid = random::New64() & INT64_MAX;
    } while (uid == InvalidKey());
    return uid;
  }

  // We store Alloc records inside a std::list<Alloc> so we can maintain LRU
  // order, and we store the list iterators within the handle map, as list
  // iterators are not invalidated by removals of other elements or by
  // position swaps.
  mutex lock_;
  AllocList allocs_;
  std::unordered_map<int64_t, AllocList::iterator> alloc_map_;
};

XRTMemoryManager::WorkingSet::WorkingSet(
    RefPtr<XRTMemoryManager> memory_manager)
    : memory_manager_(std::move(memory_manager)) {}

XRTMemoryManager::WorkingSet::~WorkingSet() {
  for (auto& tuple : pinned_tuples_) {
    tuple->Unpin();
  }
}

Status XRTMemoryManager::WorkingSet::LookupAndPin(
    xla::Backend* backend, int64_t handle,
    se::DeviceMemoryAllocator* allocator) {
  TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle));
  TF_RETURN_IF_ERROR(
      tuple->PinAndSwapIn(memory_manager_.get(), backend, allocator).status());
  pinned_tuples_.push_back(std::move(tuple));
  return OkStatus();
}
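
// A minimal usage sketch (rm, backend, handle and allocator are assumed to be
// supplied by the calling XRT op kernel):
//
//   RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
//   XRTMemoryManager::WorkingSet working_set(memory_manager);
//   TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, handle, allocator));
//   // The looked-up tuples stay pinned (resident in device memory) until
//   // working_set goes out of scope, at which point they are unpinned.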

/* static */ RefPtr<XRTMemoryManager> XRTMemoryManager::Get(ResourceMgr* rm) {
  static string* container = new string("XrtState");
  static string* name = new string("MemoryManager");
  XRTMemoryManager* memory_manager = nullptr;
  TF_CHECK_OK(rm->LookupOrCreate<XRTMemoryManager>(
      *container, *name, &memory_manager, [](XRTMemoryManager** ret) {
        *ret = new XRTMemoryManager();
        return OkStatus();
      }));
  return memory_manager;
}

int64_t XRTMemoryManager::Register(RefPtr<XRTTupleAllocation> tuple) {
  DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(),
                                                   /*create_if_missing=*/true);
  return device_context->Register(std::move(tuple));
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> XRTMemoryManager::Lookup(
    int64_t handle) {
  int device_ordinal = GetDeviceFromHandle(handle);
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  RefPtr<XRTTupleAllocation> tuple = device_context->Lookup(handle);
  if (tuple == nullptr) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  return std::move(tuple);
}

Status XRTMemoryManager::Release(int64_t handle) {
  int device_ordinal = GetDeviceFromHandle(handle);
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr || !device_context->Release(handle)) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  return OkStatus();
}

Status XRTMemoryManager::CompactAllocations(
    xla::Backend* backend, int device_ordinal,
    se::DeviceMemoryAllocator* allocator) {
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  return device_context != nullptr
             ? device_context->CompactAllocations(this, backend, allocator)
             : OkStatus();
}

void XRTMemoryManager::ReleaseAllAllocations() {
  mutex_lock lock(lock_);
  for (auto& device_context : device_contexts_) {
    if (device_context != nullptr) {
      device_context->Clear();
    }
  }
}

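// Attempts the device allocation once without retry; if the allocator reports
// RESOURCE_EXHAUSTED, it tries to swap out unpinned allocations via
// TryFreeMemory() and then retries the allocation a single time.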
xla::StatusOr<se::OwningDeviceMemory> XRTMemoryManager::Allocate(
    xla::Backend* backend, int device_ordinal, size_t size,
    se::DeviceMemoryAllocator* allocator) {
  auto memory_or =
      allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false);
  if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) {
    VLOG(4) << "Allocate of " << size << " bytes failed on device "
            << device_ordinal;

    DeviceContext* device_context =
        GetDeviceContext(device_ordinal,
                         /*create_if_missing=*/false);
    if (device_context != nullptr) {
      Status status = device_context->TryFreeMemory(backend, size).status();
      if (status.ok()) {
        // As long as there is no error, we retry the allocation, even if the
        // TryFreeMemory() call ended up freeing less memory than the required
        // size. Because of fragmentation, the allocation could still succeed
        // even if the amount of freed memory is lower than requested.
        memory_or = allocator->Allocate(device_ordinal, size,
                                        /*retry_on_failure=*/false);
      } else if (status.code() != error::RESOURCE_EXHAUSTED) {
        VLOG(4) << "Allocate of " << size << " bytes on device "
                << device_ordinal << ": " << status;
        return status;
      }
    }
  }
  return memory_or;
}

string XRTMemoryManager::DebugString() const {
  // We might want to emit more detailed information here, like per device
  // memory allocations.
  return "XRTMemoryManager";
}

XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext(
    int device_ordinal, bool create_if_missing) {
  mutex_lock lock(lock_);
  if (device_ordinal >= device_contexts_.size()) {
    if (!create_if_missing) {
      return nullptr;
    }
    device_contexts_.resize(device_ordinal + 1);
  }
  DeviceContext* device_context = device_contexts_[device_ordinal].get();
  if (device_context == nullptr && create_if_missing) {
    device_contexts_[device_ordinal] = absl::make_unique<DeviceContext>();
    device_context = device_contexts_[device_ordinal].get();
  }
  return device_context;
}

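// One step of the memory reclaim protocol: first try to swap out unpinned
// allocations (in chunks, when no explicit size was requested), then fall back
// to compacting the device allocations. Callers are expected to invoke this in
// a retry loop around a failing operation, passing in the operation's failure
// status, which is returned unchanged once no further progress can be made.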
Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx,
                                           const Status& status) {
  DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr) {
    return status;
  }
  if (!mrctx->done_freeing) {
    // If the caller passed us a zero requested_free_size, we try to free
    // chunks of kMaxFreeSize bytes, until either the run function succeeds,
    // or we run out of freeable memory.
    const size_t kMaxFreeSize = 1000000000;
    size_t free_size =
        (mrctx->requested_free_size > 0)
            ? std::min<size_t>(mrctx->requested_free_size - mrctx->free_size,
                               kMaxFreeSize)
            : kMaxFreeSize;
    if (free_size > 0) {
      auto free_size_or =
          device_context->TryFreeMemory(mrctx->backend, free_size);
      if (!free_size_or.ok()) {
        return status;
      }
      size_t size = free_size_or.ValueOrDie();
      mrctx->free_size += size;
      if (size > 0) {
        return OkStatus();
      }
    }
    mrctx->done_freeing = true;
  }
  if (!mrctx->done_compacting) {
    mrctx->done_compacting = true;
    if (device_context
            ->CompactAllocations(this, mrctx->backend, mrctx->allocator)
            .ok()) {
      return OkStatus();
    }
  }
  return status;
}

}  // namespace tensorflow