/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_memory_manager.h"

#include <algorithm>
#include <list>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace {

// We use kDeviceBits to store the device ordinal in the handle. We store the
// device ordinal in the upper bits of the int64 handle so that the random
// bits end up in the lower part, which hashes better when the handle is used
// as a key in unordered maps.
const int kDeviceBits = 12;

int64_t MakeDeviceHandle(int64_t device_ordinal, int64_t rnd_value) {
  const int64_t kUidMask = (static_cast<int64_t>(1) << (64 - kDeviceBits)) - 1;
  return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
}

int GetDeviceFromHandle(int64_t handle) {
  return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
}
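
// Example: with kDeviceBits = 12, MakeDeviceHandle(3, 0x5) puts the device
// ordinal in the top 12 bits and the uid in the low 52 bits, producing
// 0x0030000000000005; GetDeviceFromHandle(0x0030000000000005) returns 3.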

}  // namespace

class XRTMemoryManager::DeviceContext {
  struct Alloc {
    explicit Alloc(RefPtr<XRTTupleAllocation> tuple)
        : tuple(std::move(tuple)) {}

    RefPtr<XRTTupleAllocation> tuple;
  };

  using AllocList = std::list<Alloc>;

 public:
  int64_t Register(RefPtr<XRTTupleAllocation> tuple) {
    while (true) {
      int64_t handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid());
      mutex_lock lock(lock_);
      allocs_.emplace_front(tuple);
      if (alloc_map_.emplace(handle, allocs_.begin()).second) {
        return handle;
      }
      // The chances of hitting an existing handle are so remote that it is
      // more convenient to add the entry to the list up front, and remove it
      // in the unlikely case the handle already exists.
      allocs_.erase(allocs_.begin());
    }
  }

  bool Release(int64_t handle) {
    mutex_lock lock(lock_);
    auto it = alloc_map_.find(handle);
    if (it == alloc_map_.end()) {
      return false;
    }
    allocs_.erase(it->second);
    alloc_map_.erase(it);
    return true;
  }

  RefPtr<XRTTupleAllocation> Lookup(int64_t handle) {
    mutex_lock lock(lock_);
    auto it = alloc_map_.find(handle);
    if (it == alloc_map_.end()) {
      return nullptr;
    }
    // Move the entry to the front of the list to preserve LRU ordering.
    allocs_.splice(allocs_.begin(), allocs_, it->second);
    return it->second->tuple;
  }

  void Clear() {
    mutex_lock lock(lock_);
    alloc_map_.clear();
    allocs_.clear();
  }

  Status CompactAllocations(XRTMemoryManager* memory_manager,
                            xla::Backend* backend,
                            se::DeviceMemoryAllocator* allocator) {
    profiler::TraceMe trace_me("XRTMemoryManager::CompactAllocations",
                               /*level=*/2);
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMemoryCompactCell());
    VLOG(4) << "CompactAllocations started";
    mutex_lock lock(lock_);
    Status status;
    std::vector<AllocList::iterator> swapped;
    // We swap out starting from the most recently used allocations. This is
    // desirable since, once restored, the most recently used allocations end
    // up at the bottom of the allocation space. Since those are more likely
    // to be pinned allocations, a further trim by a following TryFreeMemory()
    // call will eventually drop the higher-located allocations, with a better
    // chance of reducing fragmentation.
    // Also, by swapping out the pinned allocations first, they will be the
    // first to be restored, so if we ever hit an OOM on the way back in, we
    // are more likely to be swapping in unpinned ones.
    for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
      // We are compacting all the allocations, so we will temporarily swap
      // out even pinned allocations.
      auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true);
      if (!swap_result_or.ok()) {
        status = swap_result_or.status();
        break;
      }
      if (swap_result_or.ValueOrDie()) {
        swapped.push_back(it);
      }
    }
    // At this point we have released all the device memory we could release.
    // Load back the tuple allocations we have swapped out above.
    for (auto& it : swapped) {
      auto swap_result_or =
          it->tuple->SwapIn(memory_manager, backend, allocator);
      if (!swap_result_or.ok()) {
        // If we failed to restore a pinned allocation, it is better to CHECK
        // here than to wonder later why XRTTupleAllocation calls fail with
        // errors about missing buffers.
        CHECK(!it->tuple->IsPinned());  // Crash OK
        if (status.ok()) {
          status = swap_result_or.status();
        }
      }
    }
    VLOG(4) << "CompactAllocations finished: " << status;
    return status;
  }

  // Tries to free size bytes by freeing some unpinned device memory. Returns
  // the amount of memory it was able to free.
  xla::StatusOr<size_t> TryFreeMemory(xla::Backend* backend, size_t size) {
    profiler::TraceMe trace_me("XRTMemoryManager::TryFreeMemory", /*level=*/2);
    auto timed = monitoring::MakeTimed(xrt_metrics::GetTryFreeMemoryCell());
    mutex_lock lock(lock_);
    size_t swapped_size = 0;
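    // Walk the list from the back, which holds the least recently used
    // allocations, so that hot allocations are the last to be swapped out.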
    for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
      TF_ASSIGN_OR_RETURN(bool swap_result,
                          it->tuple->SwapOut(backend, /*swap_pinned=*/false));
      if (swap_result) {
        swapped_size += it->tuple->GetDeviceMemorySize();
        if (swapped_size >= size) {
          break;
        }
      }
    }
    VLOG(3) << "Swapped out " << swapped_size << " bytes";
    return swapped_size;
  }

 private:
  static int64_t CreateUid() {
    int64_t uid;
    do {
      uid = random::New64() & INT64_MAX;
    } while (uid == InvalidKey());
    return uid;
  }

  // We store Alloc records inside a std::list<Alloc> so we can maintain LRU
  // order, and keep the list iterators within the handle map, since list
  // iterators are not invalidated by removals of other elements or by
  // position swaps.
  mutex lock_;
  AllocList allocs_;
  std::unordered_map<int64_t, AllocList::iterator> alloc_map_;
};

XRTMemoryManager::WorkingSet::WorkingSet(
    RefPtr<XRTMemoryManager> memory_manager)
    : memory_manager_(std::move(memory_manager)) {}

XRTMemoryManager::WorkingSet::~WorkingSet() {
  for (auto& tuple : pinned_tuples_) {
    tuple->Unpin();
  }
}

Status XRTMemoryManager::WorkingSet::LookupAndPin(
    xla::Backend* backend, int64_t handle,
    se::DeviceMemoryAllocator* allocator) {
  TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle));
  TF_RETURN_IF_ERROR(
      tuple->PinAndSwapIn(memory_manager_.get(), backend, allocator).status());
  pinned_tuples_.push_back(std::move(tuple));
  return OkStatus();
}
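
// A minimal usage sketch for WorkingSet (illustrative; `memory_manager`,
// `backend`, `handle` and `allocator` are assumed to come from the caller):
//
//   XRTMemoryManager::WorkingSet working_set(memory_manager);
//   TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, handle, allocator));
//   // The looked-up tuple stays pinned (and resident) until working_set
//   // goes out of scope, at which point it is unpinned automatically.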

/* static */ RefPtr<XRTMemoryManager> XRTMemoryManager::Get(ResourceMgr* rm) {
  static string* container = new string("XrtState");
  static string* name = new string("MemoryManager");
  XRTMemoryManager* memory_manager = nullptr;
  TF_CHECK_OK(rm->LookupOrCreate<XRTMemoryManager>(
      *container, *name, &memory_manager, [](XRTMemoryManager** ret) {
        *ret = new XRTMemoryManager();
        return OkStatus();
      }));
  return memory_manager;
}
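
// Illustrative handle round trip through the manager (sketch; `rm` and
// `tuple` are assumed to come from the surrounding op context):
//
//   RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
//   int64_t handle = memory_manager->Register(std::move(tuple));
//   TF_ASSIGN_OR_RETURN(auto alloc, memory_manager->Lookup(handle));
//   TF_RETURN_IF_ERROR(memory_manager->Release(handle));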

int64_t XRTMemoryManager::Register(RefPtr<XRTTupleAllocation> tuple) {
  DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(),
                                                   /*create_if_missing=*/true);
  return device_context->Register(std::move(tuple));
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> XRTMemoryManager::Lookup(
    int64_t handle) {
  int device_ordinal = GetDeviceFromHandle(handle);
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  RefPtr<XRTTupleAllocation> tuple = device_context->Lookup(handle);
  if (tuple == nullptr) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  return std::move(tuple);
}

Status XRTMemoryManager::Release(int64_t handle) {
  int device_ordinal = GetDeviceFromHandle(handle);
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr || !device_context->Release(handle)) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  return OkStatus();
}

Status XRTMemoryManager::CompactAllocations(
    xla::Backend* backend, int device_ordinal,
    se::DeviceMemoryAllocator* allocator) {
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  return device_context != nullptr
             ? device_context->CompactAllocations(this, backend, allocator)
             : OkStatus();
}

void XRTMemoryManager::ReleaseAllAllocations() {
  mutex_lock lock(lock_);
  for (auto& device_context : device_contexts_) {
    if (device_context != nullptr) {
      device_context->Clear();
    }
  }
}

xla::StatusOr<se::OwningDeviceMemory> XRTMemoryManager::Allocate(
    xla::Backend* backend, int device_ordinal, size_t size,
    se::DeviceMemoryAllocator* allocator) {
  auto memory_or =
      allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false);
  if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) {
    VLOG(4) << "Allocate of " << size << " bytes failed on device "
            << device_ordinal;

    DeviceContext* device_context =
        GetDeviceContext(device_ordinal,
                         /*create_if_missing=*/false);
    if (device_context != nullptr) {
      Status status = device_context->TryFreeMemory(backend, size).status();
      if (status.ok()) {
        // As long as there is no error, we retry the allocation, even if the
        // TryFreeMemory() call ended up freeing less memory than the
        // requested size; the original failure may have been due to
        // fragmentation, so the allocation could still succeed.
        memory_or = allocator->Allocate(device_ordinal, size,
                                        /*retry_on_failure=*/false);
      } else if (status.code() != error::RESOURCE_EXHAUSTED) {
        VLOG(4) << "Allocate of " << size << " bytes on device "
                << device_ordinal << ": " << status;
        return status;
      }
    }
  }
  return memory_or;
}

string XRTMemoryManager::DebugString() const {
  // We might want to emit more detailed information here, like per device
  // memory allocations.
  return "XRTMemoryManager";
}

XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext(
    int device_ordinal, bool create_if_missing) {
  mutex_lock lock(lock_);
  if (device_ordinal >= device_contexts_.size()) {
    if (!create_if_missing) {
      return nullptr;
    }
    device_contexts_.resize(device_ordinal + 1);
  }
  DeviceContext* device_context = device_contexts_[device_ordinal].get();
  if (device_context == nullptr && create_if_missing) {
    device_contexts_[device_ordinal] = absl::make_unique<DeviceContext>();
    device_context = device_contexts_[device_ordinal].get();
  }
  return device_context;
}

Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx,
                                           const Status& status) {
  DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr) {
    return status;
  }
  if (!mrctx->done_freeing) {
    // If the caller passed us a zero requested_free_size, we try to free
    // chunks of kMaxFreeSize memory until either the run function succeeds
    // or we run out of freeable memory. Otherwise we free whatever portion
    // of the requested size has not been freed yet, capped at kMaxFreeSize
    // per step.
    const size_t kMaxFreeSize = 1000000000;
    size_t free_size =
        (mrctx->requested_free_size > 0)
            ? std::min<size_t>(mrctx->requested_free_size - mrctx->free_size,
                               kMaxFreeSize)
            : kMaxFreeSize;
    if (free_size > 0) {
      auto free_size_or =
          device_context->TryFreeMemory(mrctx->backend, free_size);
      if (!free_size_or.ok()) {
        return status;
      }
      size_t size = free_size_or.ValueOrDie();
      mrctx->free_size += size;
      if (size > 0) {
        return OkStatus();
      }
    }
    mrctx->done_freeing = true;
  }
  if (!mrctx->done_compacting) {
    mrctx->done_compacting = true;
    if (device_context
            ->CompactAllocations(this, mrctx->backend, mrctx->allocator)
            .ok()) {
      return OkStatus();
    }
  }
  return status;
}
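
// Sketch of how TryFreeMemoryStep() can drive an allocate-retry loop
// (illustrative only; `run_fn` and the context setup are assumptions, not
// part of this file):
//
//   MemoryReclaimContext mrctx = ...;  // backend, device ordinal, allocator
//   Status status = run_fn();
//   while (!status.ok()) {
//     // Returns OkStatus() only if it made progress (freed or compacted).
//     Status reclaim = memory_manager->TryFreeMemoryStep(&mrctx, status);
//     if (!reclaim.ok()) {
//       return reclaim;  // Nothing left to reclaim; surface the error.
//     }
//     status = run_fn();
//   }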

}  // namespace tensorflow