/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
#define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase
// object which lives within the ResourceMgr. There is only one XRT memory
// manager object within the ResourceMgr container.
class XRTMemoryManager : public ResourceBase {
  // The DeviceContext class, defined and implemented locally inside the
  // xrt_memory_manager.cc file, holds, for each device, all the information
  // related to the XRT memory management for that device.
  class DeviceContext;
 public:
  // A working set is a set of tuple allocations which are the input of a given
  // operation, and as such they must be pinned in device memory. The tuple
  // allocations added to the WorkingSet will be unpinned at object destruction.
  class WorkingSet {
   public:
    explicit WorkingSet(RefPtr<XRTMemoryManager> memory_manager);

    ~WorkingSet();

    // Looks up the tuple handle within the memory manager, and pins it to the
    // device (if not already pinned).
    Status LookupAndPin(xla::Backend* backend, int64_t handle,
                        se::DeviceMemoryAllocator* allocator);

    const std::vector<RefPtr<XRTTupleAllocation>>& PinnedTuples() const {
      return pinned_tuples_;
    }

    const RefPtr<XRTMemoryManager>& MemoryManager() const {
      return memory_manager_;
    }

   private:
    RefPtr<XRTMemoryManager> memory_manager_;
    std::vector<RefPtr<XRTTupleAllocation>> pinned_tuples_;
  };
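
  // Illustrative usage sketch, not part of the API: `backend`, `allocator`,
  // and `input_handle` are assumptions standing in for values owned by the
  // surrounding XRT op kernel.
  //
  //   XRTMemoryManager::WorkingSet working_set(memory_manager);
  //   TF_RETURN_IF_ERROR(
  //       working_set.LookupAndPin(backend, input_handle, allocator));
  //   const auto& inputs = working_set.PinnedTuples();
  //   // ... run the operation against the pinned tuples ...
  //   // Every pinned tuple is unpinned when working_set goes out of scope.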

  // Retrieves the XRTMemoryManager singleton stored within the ResourceMgr.
  static RefPtr<XRTMemoryManager> Get(ResourceMgr* rm);

  // Registers an XRTTupleAllocation and returns the unique handle identifying
  // it.
  int64_t Register(RefPtr<XRTTupleAllocation> tuple);

  // Looks up a handle returned by the Register() API and returns the
  // XRTTupleAllocation behind it.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(int64_t handle);

  Status Lookup(int64_t handle, RefPtr<XRTTupleAllocation>* tuple) {
    TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle));
    return OkStatus();
  }

  // Releases a handle by dropping the reference count held on the
  // XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation
  // references will continue to be valid.
  Status Release(int64_t handle);
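
  // Illustrative handle lifecycle sketch; `rm` (the ResourceMgr of the op
  // kernel) and `tuple` (an existing RefPtr<XRTTupleAllocation>) are assumed
  // to be provided by the caller.
  //
  //   RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  //   int64_t handle = memory_manager->Register(std::move(tuple));
  //   TF_ASSIGN_OR_RETURN(RefPtr<XRTTupleAllocation> allocation,
  //                       memory_manager->Lookup(handle));
  //   // ... use allocation ...
  //   TF_RETURN_IF_ERROR(memory_manager->Release(handle));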

  // Tries to compact all the memory allocations on a given device. This is
  // currently done by swapping out all the existing allocations and swapping
  // them back in.
  Status CompactAllocations(xla::Backend* backend, int device_ordinal,
                            se::DeviceMemoryAllocator* allocator);

  // Releases all the device memory allocated by XRT within the resource
  // manager.
  void ReleaseAllAllocations();

  // Tries to allocate size bytes of device memory from the device_ordinal
  // device. If the underlying allocator call fails, it might free some
  // unpinned device memory and retry the allocation.
  xla::StatusOr<se::OwningDeviceMemory> Allocate(
      xla::Backend* backend, int device_ordinal, size_t size,
      se::DeviceMemoryAllocator* allocator);
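
  // Illustrative sketch; `memory_manager`, `backend`, `device_ordinal`, and
  // `allocator` are assumed to come from the calling op kernel.
  //
  //   TF_ASSIGN_OR_RETURN(
  //       se::OwningDeviceMemory buffer,
  //       memory_manager->Allocate(backend, device_ordinal, /*size=*/1024,
  //                                allocator));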

  // Runs the specified function and handles the error::RESOURCE_EXHAUSTED
  // status code coming out of it. In such cases, it runs progressively more
  // expensive memory freeing operations to try to make runfn succeed. The
  // requested_free_size argument is a hint of the amount of memory which would
  // make runfn succeed.
  template <typename T>
  xla::StatusOr<T> Run(const std::function<xla::StatusOr<T>()>& runfn,
                       xla::Backend* backend, int device_ordinal,
                       size_t requested_free_size,
                       se::DeviceMemoryAllocator* allocator);
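
  // Illustrative usage sketch; `memory_manager`, `backend`, `device_ordinal`,
  // `allocator`, `output_size`, and the AllocateOutputTuple() helper are
  // hypothetical stand-ins for caller-owned state.
  //
  //   auto runfn = [&]() -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
  //     return AllocateOutputTuple();  // hypothetical allocation-heavy call
  //   };
  //   TF_ASSIGN_OR_RETURN(
  //       RefPtr<XRTTupleAllocation> output,
  //       memory_manager->Run<RefPtr<XRTTupleAllocation>>(
  //           runfn, backend, device_ordinal,
  //           /*requested_free_size=*/output_size, allocator));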

  string DebugString() const override;

  // Returns the invalid key value, which will never be generated by the
  // Intern() API.
  static int64_t InvalidKey() { return 0; }

 private:
  // Structure used to track the progress of a try-to-free operation. It is
  // initialized and then passed to the TryFreeMemoryStep() API.
  struct MemoryReclaimContext {
    MemoryReclaimContext(xla::Backend* backend, int device_ordinal,
                         size_t requested_free_size,
                         se::DeviceMemoryAllocator* specific_allocator)
        : backend(backend),
          allocator(specific_allocator),
          device_ordinal(device_ordinal),
          requested_free_size(requested_free_size) {}

    xla::Backend* const backend = nullptr;
    se::DeviceMemoryAllocator* allocator = nullptr;
    const int device_ordinal = 0;
    const size_t requested_free_size = 0;
    size_t free_size = 0;
    bool done_freeing = false;
    bool done_compacting = false;
  };

  DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing);

  // Called multiple times while trying to make a memory consuming function call
  // to fit. Performs progressively more expensive memory reduction operations,
  // until returning error::RESOURCE_EXHAUSTED when no further reductions are
  // possible.
  Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status);

  mutex lock_;
  std::vector<std::unique_ptr<DeviceContext>> device_contexts_;
};

template <typename T>
xla::StatusOr<T> XRTMemoryManager::Run(
    const std::function<xla::StatusOr<T>()>& runfn, xla::Backend* backend,
    int device_ordinal, size_t requested_free_size,
    se::DeviceMemoryAllocator* allocator) {
  MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size,
                             allocator);
  while (true) {
    // We assume that runfn is a relatively fast-fail function compared to the
    // operations required to free up the required memory. Here we call into the
    // TryFreeMemoryStep() API multiple times, which will run progressively more
    // expensive operations.
    auto result_or = runfn();
    if (result_or.status().code() != error::RESOURCE_EXHAUSTED) {
      return result_or;
    }
    TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status()));
  }
}

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_