1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_
18
19 #include <map>
20 #include <vector>
21
22 #include "absl/base/thread_annotations.h"
23 #include "absl/synchronization/mutex.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
26 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
27 #include "tensorflow/compiler/xla/stream_executor/platform.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/platform/types.h"
31
32 namespace stream_executor {
33
34 class DeviceMemoryAllocator;
35
36 // Owning pointer for memory on a device.
37 //
38 // ScopedDeviceMemory is an owning pointer like std::unique_ptr, but it can
39 // point to memory that resides on a "device" (e.g. a GPU). When a
40 // ScopedDeviceMemory goes out of scope, it frees the memory it owns.
41 //
42 // We say that an instance of ScopedDeviceMemory is "active" if it currently
43 // owns a (possibly empty) slice of memory on the device. Moving,
// Release()'ing, Free()'ing, and other actions can deactivate an active object.
45 template <typename ElemT>
46 class ScopedDeviceMemory {
47 public:
48 // Default construction initializes the internal state to nullptr. This
49 // mirrors the std::unique_ptr<> functionality, where default construction
50 // produces a nullptr unique_ptr, which can be assigned later.
ScopedDeviceMemory()51 ScopedDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}
52
53 // Construct a ScopedDeviceMemory from a custom allocator.
54 //
55 // Parameters:
56 // mem: Already-allocated device memory value for this scoped mechanism to
57 // deallocate. This memory must have been allocated by parent.
58 // device_ordinal: Device on which the memory was allocated.
59 // allocator: Allocator used to deallocate memory when this instance goes
60 // out of scope.
ScopedDeviceMemory(DeviceMemoryBase mem,int device_ordinal,DeviceMemoryAllocator * allocator)61 ScopedDeviceMemory(DeviceMemoryBase mem, int device_ordinal,
62 DeviceMemoryAllocator *allocator)
63 : wrapped_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {
64 DCHECK_GE(device_ordinal_, 0);
65 }
66
67 // A helper constructor to generate a scoped device memory given an already
68 // allocated memory and a stream executor.
69 //
70 // Precondition: memory was allocated by the stream executor `parent`.
71 ScopedDeviceMemory(StreamExecutor *parent, DeviceMemoryBase value);
72
73 // Constructor overload that places a literal array into device memory.
74 //
75 // Relies on the allocation function exposed by the stream executor `parent`,
76 // which will be also used for deallocating the memory
77 ScopedDeviceMemory(StreamExecutor *parent,
78 std::initializer_list<ElemT> values);
79
80 // Moves ownership of the memory from other to the constructed
81 // object.
82 //
83 // Postcondition: other == nullptr.
ScopedDeviceMemory(ScopedDeviceMemory && other)84 ScopedDeviceMemory(ScopedDeviceMemory &&other)
85 : wrapped_(other.Release()),
86 device_ordinal_(other.device_ordinal_),
87 allocator_(other.allocator_) {}
88
89 // Releases the memory that was provided in the constructor, through the
90 // "parent" StreamExecutor.
~ScopedDeviceMemory()91 ~ScopedDeviceMemory() { TF_CHECK_OK(Free()); }
92
93 // Moves ownership of the memory from other to this object.
94 //
95 // Postcondition: other == nullptr.
96 ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) {
97 TF_CHECK_OK(Free());
98 wrapped_ = other.Release();
99 allocator_ = other.allocator_;
100 device_ordinal_ = other.device_ordinal_;
101 return *this;
102 }
103
104 // Returns the memory that backs this scoped allocation converted to
105 // DeviceMemory<T> apparent type. This is useful for cases where the
106 // DeviceMemory must be passed by const-ref, as the ScopedDeviceMemory doesn't
107 // allow copying, for scoped-object-lifetime reasons.
cref()108 const DeviceMemory<ElemT> &cref() const { return wrapped_; }
109
110 // Returns a pointer to the DeviceMemory<T> apparent type for use in mutable
111 // operations. The value returned should not be used outside the scope of this
112 // ScopedDeviceMemory object's lifetime.
ptr()113 DeviceMemory<ElemT> *ptr() { return &wrapped_; }
ptr()114 const DeviceMemory<ElemT> *ptr() const { return &wrapped_; }
115
116 // Smart-pointer-like operators for the wrapped DeviceMemory.
117 // This reference must not be used outside the lifetime of this
118 // ScopedDeviceMemory.
119 const DeviceMemory<ElemT> &operator*() const { return cref(); }
120 DeviceMemory<ElemT> *operator->() { return ptr(); }
121 const DeviceMemory<ElemT> *operator->() const { return ptr(); }
122
is_null()123 bool is_null() const { return wrapped_.is_null(); }
124 bool operator==(std::nullptr_t other) const { return is_null(); }
125 bool operator!=(std::nullptr_t other) const { return !is_null(); }
126
127 // Analogous to std::unique_ptr::release, releases ownership of the held
128 // memory and transfers it to the caller.
129 //
130 // Postcondition: *this == nullptr
Release()131 DeviceMemory<ElemT> Release() {
132 DeviceMemory<ElemT> tmp = wrapped_;
133 wrapped_ = DeviceMemory<ElemT>{};
134 return tmp;
135 }
136
137 // The returned allocator is nonnull iff this object is active.
allocator()138 DeviceMemoryAllocator *allocator() const { return allocator_; }
139
device_ordinal()140 int device_ordinal() const { return device_ordinal_; }
141
142 // Frees the existing memory, resets the wrapped memory to null.
143 port::Status Free();
144
145 private:
146 DeviceMemory<ElemT> wrapped_; // Value we wrap with scoped-release.
147 int device_ordinal_; // Negative one for inactive object.
148 DeviceMemoryAllocator *allocator_; // Null if this object is inactive.
149
150 SE_DISALLOW_COPY_AND_ASSIGN(ScopedDeviceMemory);
151 };
152
// Type alias for compatibility with the previous managed memory
// implementation: an owning uint8 buffer on a device.
using OwningDeviceMemory = ScopedDeviceMemory<uint8>;
155
156 // Memory allocator interface for the device.
157 //
158 // Intended usage is through Allocate() functions which return an owning smart
159 // pointer.
160 class DeviceMemoryAllocator {
161 public:
162 // Parameter platform indicates which platform the allocator allocates memory
163 // on. Must be non-null.
DeviceMemoryAllocator(const Platform * platform)164 explicit DeviceMemoryAllocator(const Platform *platform)
165 : platform_(platform) {}
~DeviceMemoryAllocator()166 virtual ~DeviceMemoryAllocator() {}
167
168 // Allocates memory on the device.
169 //
170 // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory
171 // must not be null. If size == 0, must return a null OwningDeviceMemory.
172 //
173 // 'retry_on_failure': If false, and the first attempt to allocate the memory
174 // fails, the allocation should return immediately without retrying. An
175 // example use case is optional scratch spaces where a failure has only
176 // performance impact.
177 virtual port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
178 uint64_t size,
179 bool retry_on_failure,
180 int64_t memory_space) = 0;
181
182 // Two-arg version of Allocate(), which sets retry-on-failure to true and
183 // memory_space to default (0).
184 //
185 // (We don't simply use a default argument on the virtual Allocate function
186 // because default args on virtual functions are disallowed by the Google
187 // style guide.)
Allocate(int device_ordinal,uint64_t size)188 port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
189 uint64_t size) {
190 return Allocate(device_ordinal, size, /*retry_on_failure=*/true,
191 /*memory_space=*/0);
192 }
193
194 // Three-arg version of Allocate(), which sets memory_space to default (0).
Allocate(int device_ordinal,uint64_t size,bool retry_on_failure)195 port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64_t size,
196 bool retry_on_failure) {
197 return Allocate(device_ordinal, size, retry_on_failure,
198 /*memory_space=*/0);
199 }
200
201 // Typed version of the allocation, returning typed memory.
202 template <typename ElemT>
203 port::StatusOr<ScopedDeviceMemory<ElemT>> Allocate(
204 int device_ordinal, uint64_t size, bool retry_on_failure = true,
205 int64_t memory_space = 0) {
206 return Allocate(device_ordinal, size, retry_on_failure, memory_space);
207 }
208
209 // Must be a nop for null pointers. Should not be used.
210 //
211 // TODO(cheshire): Add deprecation notice.
212 virtual port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0;
213
214 // Return the platform that the allocator allocates memory on.
platform()215 const Platform *platform() const { return platform_; }
216
217 // Can we call Deallocate() as soon as a computation has been scheduled on
218 // a stream, or do we have to wait for the computation to complete first?
AllowsAsynchronousDeallocation()219 virtual bool AllowsAsynchronousDeallocation() const { return false; }
220
221 // Returns a stream pointer on which it is always safe to access memory
222 // allocated by this allocator. It is not necessary to use the returned stream
223 // though, as clients may have additional information letting them safely use
224 // a different stream.
225 virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;
226
227 protected:
228 const Platform *platform_;
229 };
230
231 // Default memory allocator for a platform which uses
232 // StreamExecutor::Allocate/Deallocate.
class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
 public:
  // Create an allocator supporting a single device, corresponding to the
  // passed executor.
  explicit StreamExecutorMemoryAllocator(StreamExecutor *executor);

  // Create an allocator supporting multiple stream executors.
  //
  // Precondition: all stream_executors have different device ordinals.
  StreamExecutorMemoryAllocator(
      const Platform *platform,
      absl::Span<StreamExecutor *const> stream_executors);

  // Allocates via the stream executor matching `device_ordinal`; see the
  // base-class contract for the size-zero and retry semantics.
  port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64_t size,
                                              bool retry_on_failure,
                                              int64_t memory_space) override;

  // Pull in two-arg overload that sets retry_on_failure to true.
  using DeviceMemoryAllocator::Allocate;

  // Deallocates `mem` through the stream executor for `device_ordinal`.
  // Per the base-class contract this must be a nop for null pointers.
  port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override;

  bool AllowsAsynchronousDeallocation() const override;

  // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
  // stream executor.
  port::StatusOr<Stream *> GetStream(int device_ordinal) override;

  // Gets the stream executor for given device ordinal.
  port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;

 private:
  // Available stream executors. Each stream executor has a different device
  // ordinal.
  std::vector<StreamExecutor *> stream_executors_;

  // Guards streams_ (GetStream may be called from multiple threads).
  absl::Mutex mutex_;

  // Cache of streams for GetStream, keyed by device ordinal.
  std::map<int, Stream> streams_ ABSL_GUARDED_BY(mutex_);
};
274
275 template <typename ElemT>
Free()276 port::Status ScopedDeviceMemory<ElemT>::Free() {
277 if (!wrapped_.is_null()) {
278 CHECK(allocator_ != nullptr) << "Owning pointer in inconsistent state";
279 TF_RETURN_IF_ERROR(allocator_->Deallocate(device_ordinal_, wrapped_));
280 }
281 wrapped_ = DeviceMemory<ElemT>{};
282 return ::tensorflow::OkStatus();
283 }
284
285 } // namespace stream_executor
286
287 #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_
288