/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
#define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {
// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase
// object which lives within the ResourceMgr. There is only one XRT memory
// manager object within the ResourceMgr container.
class XRTMemoryManager : public ResourceBase {
  // The DeviceContext class, defined and implemented locally inside the
  // xrt_memory_manager.cc file, holds, for each device, all the information
  // related to the XRT memory management for that device.
  class DeviceContext;

 public:
  // A working set is a set of tuple allocations which are the input of a given
  // operation, and as such must be pinned in device memory. The tuple
  // allocations added to the WorkingSet will be unpinned at object destruction.
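  //
  // Example usage (a sketch; assumes a valid memory manager, backend,
  // allocator, and a handle previously returned by Register()):
  //
  //   XRTMemoryManager::WorkingSet working_set(memory_manager);
  //   TF_RETURN_IF_ERROR(
  //       working_set.LookupAndPin(backend, handle, allocator));
  //   const auto& tuples = working_set.PinnedTuples();
  //   // ... use the pinned tuples; they are unpinned when working_set
  //   // goes out of scope.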
  class WorkingSet {
   public:
    explicit WorkingSet(RefPtr<XRTMemoryManager> memory_manager);

    ~WorkingSet();

    // Looks up the tuple handle within the memory manager, and pins it to the
    // device (if not already pinned).
    Status LookupAndPin(xla::Backend* backend, int64_t handle,
                        se::DeviceMemoryAllocator* allocator);

    const std::vector<RefPtr<XRTTupleAllocation>>& PinnedTuples() const {
      return pinned_tuples_;
    }

    const RefPtr<XRTMemoryManager>& MemoryManager() const {
      return memory_manager_;
    }

   private:
    RefPtr<XRTMemoryManager> memory_manager_;
    std::vector<RefPtr<XRTTupleAllocation>> pinned_tuples_;
  };

  // Retrieves the XRTMemoryManager singleton stored within the ResourceMgr.
  static RefPtr<XRTMemoryManager> Get(ResourceMgr* rm);

  // Registers an XRTTupleAllocation and returns the unique handle identifying
  // it.
  int64_t Register(RefPtr<XRTTupleAllocation> tuple);

  // Looks up a handle returned by the Register() API and returns the
  // XRTTupleAllocation behind it.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(int64_t handle);

  Status Lookup(int64_t handle, RefPtr<XRTTupleAllocation>* tuple) {
    TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle));
    return OkStatus();
  }

  // Releases a handle by dropping the reference count held on the
  // XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation
  // references will continue to be valid.
  Status Release(int64_t handle);
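
  // Example of the handle lifecycle (a sketch; assumes `tuple` is a valid
  // RefPtr<XRTTupleAllocation>):
  //
  //   int64_t handle = memory_manager->Register(std::move(tuple));
  //   TF_ASSIGN_OR_RETURN(RefPtr<XRTTupleAllocation> alloc,
  //                       memory_manager->Lookup(handle));
  //   TF_RETURN_IF_ERROR(memory_manager->Release(handle));
  //   // `alloc` remains valid after Release(); only the manager's
  //   // reference is dropped.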

  // Tries to compact all the memory allocations on a given device. This is
  // currently done by swapping out all the existing allocations, and swapping
  // them back in.
  Status CompactAllocations(xla::Backend* backend, int device_ordinal,
                            se::DeviceMemoryAllocator* allocator);

  // Releases all the device memory allocated by XRT within the resource
  // manager.
  void ReleaseAllAllocations();

  // Tries to allocate size bytes of device memory from the device_ordinal
  // device. If the underlying allocator call fails, it might attempt to free
  // some unpinned device memory and retry the allocation.
  xla::StatusOr<se::OwningDeviceMemory> Allocate(
      xla::Backend* backend, int device_ordinal, size_t size,
      se::DeviceMemoryAllocator* allocator);
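
  // Example (a sketch; `kSize` is a hypothetical byte count):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       se::OwningDeviceMemory memory,
  //       memory_manager->Allocate(backend, /*device_ordinal=*/0, kSize,
  //                                allocator));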

  // Runs the specified function, handling any error::RESOURCE_EXHAUSTED status
  // code coming out of it. In such cases, we run different memory freeing
  // operations trying to make runfn succeed. The requested_free_size argument
  // represents a hint of the requested memory size which would make runfn
  // succeed.
  template <typename T>
  xla::StatusOr<T> Run(const std::function<xla::StatusOr<T>()>& runfn,
                       xla::Backend* backend, int device_ordinal,
                       size_t requested_free_size,
                       se::DeviceMemoryAllocator* allocator);
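
  // Example (a sketch; assumes a runfn which fails with RESOURCE_EXHAUSTED
  // when device memory is tight, such as a raw allocator call):
  //
  //   auto runfn = [&]() -> xla::StatusOr<se::OwningDeviceMemory> {
  //     return allocator->Allocate(device_ordinal, size);
  //   };
  //   TF_ASSIGN_OR_RETURN(
  //       se::OwningDeviceMemory memory,
  //       memory_manager->Run<se::OwningDeviceMemory>(
  //           runfn, backend, device_ordinal, size, allocator));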

  string DebugString() const override;

  // Returns the invalid key value, which will never be generated by the
  // Intern() API.
  static int64_t InvalidKey() { return 0; }

 private:
  // Structure used to track the progress of a try-to-free operation. It is
  // initialized and then passed to the TryFreeMemoryStep() API.
  struct MemoryReclaimContext {
    MemoryReclaimContext(xla::Backend* backend, int device_ordinal,
                         size_t requested_free_size,
                         se::DeviceMemoryAllocator* specific_allocator)
        : backend(backend),
          allocator(specific_allocator),
          device_ordinal(device_ordinal),
          requested_free_size(requested_free_size) {}

    xla::Backend* const backend = nullptr;
    se::DeviceMemoryAllocator* allocator = nullptr;
    const int device_ordinal = 0;
    const size_t requested_free_size = 0;
    size_t free_size = 0;
    bool done_freeing = false;
    bool done_compacting = false;
  };

  DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing);

  // Called multiple times while trying to make a memory-consuming function
  // call fit. Performs progressively more expensive memory reduction
  // operations, until returning error::RESOURCE_EXHAUSTED when no further
  // reductions are possible.
  Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status);

  mutex lock_;
  std::vector<std::unique_ptr<DeviceContext>> device_contexts_;
};

template <typename T>
xla::StatusOr<T> XRTMemoryManager::Run(
    const std::function<xla::StatusOr<T>()>& runfn, xla::Backend* backend,
    int device_ordinal, size_t requested_free_size,
    se::DeviceMemoryAllocator* allocator) {
  MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size,
                             allocator);
  while (true) {
    // We assume that runfn is a relatively fast-fail function compared to the
    // operations required to free up the required memory. Here we call into
    // the TryFreeMemoryStep() API multiple times, which will run progressively
    // more expensive operations.
    auto result_or = runfn();
    if (result_or.status().code() != error::RESOURCE_EXHAUSTED) {
      return result_or;
    }
    TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status()));
  }
}

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_