xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/device_memory_allocator.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/base/thread_annotations.h"
23 #include "absl/synchronization/mutex.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
26 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
27 #include "tensorflow/compiler/xla/stream_executor/platform.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace stream_executor {
33 
34 class DeviceMemoryAllocator;
35 
36 // Owning pointer for memory on a device.
37 //
38 // ScopedDeviceMemory is an owning pointer like std::unique_ptr, but it can
39 // point to memory that resides on a "device" (e.g. a GPU).  When a
40 // ScopedDeviceMemory goes out of scope, it frees the memory it owns.
41 //
42 // We say that an instance of ScopedDeviceMemory is "active" if it currently
// owns a (possibly empty) slice of memory on the device.  Moving,
// Release()'ing, Free()'ing, and other actions can deactivate an active
// object.
45 template <typename ElemT>
46 class ScopedDeviceMemory {
47  public:
48   // Default construction initializes the internal state to nullptr.  This
49   // mirrors the std::unique_ptr<> functionality, where default construction
50   // produces a nullptr unique_ptr, which can be assigned later.
ScopedDeviceMemory()51   ScopedDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}
52 
53   // Construct a ScopedDeviceMemory from a custom allocator.
54   //
55   // Parameters:
56   //  mem: Already-allocated device memory value for this scoped mechanism to
57   //       deallocate. This memory must have been allocated by parent.
58   //  device_ordinal: Device on which the memory was allocated.
59   //  allocator: Allocator used to deallocate memory when this instance goes
60   //             out of scope.
ScopedDeviceMemory(DeviceMemoryBase mem,int device_ordinal,DeviceMemoryAllocator * allocator)61   ScopedDeviceMemory(DeviceMemoryBase mem, int device_ordinal,
62                      DeviceMemoryAllocator *allocator)
63       : wrapped_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {
64     DCHECK_GE(device_ordinal_, 0);
65   }
66 
67   // A helper constructor to generate a scoped device memory given an already
68   // allocated memory and a stream executor.
69   //
70   // Precondition: memory was allocated by the stream executor `parent`.
71   ScopedDeviceMemory(StreamExecutor *parent, DeviceMemoryBase value);
72 
73   // Constructor overload that places a literal array into device memory.
74   //
75   // Relies on the allocation function exposed by the stream executor `parent`,
76   // which will be also used for deallocating the memory
77   ScopedDeviceMemory(StreamExecutor *parent,
78                      std::initializer_list<ElemT> values);
79 
80   // Moves ownership of the memory from other to the constructed
81   // object.
82   //
83   // Postcondition: other == nullptr.
ScopedDeviceMemory(ScopedDeviceMemory && other)84   ScopedDeviceMemory(ScopedDeviceMemory &&other)
85       : wrapped_(other.Release()),
86         device_ordinal_(other.device_ordinal_),
87         allocator_(other.allocator_) {}
88 
89   // Releases the memory that was provided in the constructor, through the
90   // "parent" StreamExecutor.
~ScopedDeviceMemory()91   ~ScopedDeviceMemory() { TF_CHECK_OK(Free()); }
92 
93   // Moves ownership of the memory from other to this object.
94   //
95   // Postcondition: other == nullptr.
96   ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) {
97     TF_CHECK_OK(Free());
98     wrapped_ = other.Release();
99     allocator_ = other.allocator_;
100     device_ordinal_ = other.device_ordinal_;
101     return *this;
102   }
103 
104   // Returns the memory that backs this scoped allocation converted to
105   // DeviceMemory<T> apparent type. This is useful for cases where the
106   // DeviceMemory must be passed by const-ref, as the ScopedDeviceMemory doesn't
107   // allow copying, for scoped-object-lifetime reasons.
cref()108   const DeviceMemory<ElemT> &cref() const { return wrapped_; }
109 
110   // Returns a pointer to the DeviceMemory<T> apparent type for use in mutable
111   // operations. The value returned should not be used outside the scope of this
112   // ScopedDeviceMemory object's lifetime.
ptr()113   DeviceMemory<ElemT> *ptr() { return &wrapped_; }
ptr()114   const DeviceMemory<ElemT> *ptr() const { return &wrapped_; }
115 
116   // Smart-pointer-like operators for the wrapped DeviceMemory.
117   // This reference must not be used outside the lifetime of this
118   // ScopedDeviceMemory.
119   const DeviceMemory<ElemT> &operator*() const { return cref(); }
120   DeviceMemory<ElemT> *operator->() { return ptr(); }
121   const DeviceMemory<ElemT> *operator->() const { return ptr(); }
122 
is_null()123   bool is_null() const { return wrapped_.is_null(); }
124   bool operator==(std::nullptr_t other) const { return is_null(); }
125   bool operator!=(std::nullptr_t other) const { return !is_null(); }
126 
127   // Analogous to std::unique_ptr::release, releases ownership of the held
128   // memory and transfers it to the caller.
129   //
130   // Postcondition: *this == nullptr
Release()131   DeviceMemory<ElemT> Release() {
132     DeviceMemory<ElemT> tmp = wrapped_;
133     wrapped_ = DeviceMemory<ElemT>{};
134     return tmp;
135   }
136 
137   // The returned allocator is nonnull iff this object is active.
allocator()138   DeviceMemoryAllocator *allocator() const { return allocator_; }
139 
device_ordinal()140   int device_ordinal() const { return device_ordinal_; }
141 
142   // Frees the existing memory, resets the wrapped memory to null.
143   port::Status Free();
144 
145  private:
146   DeviceMemory<ElemT> wrapped_;       // Value we wrap with scoped-release.
147   int device_ordinal_;                // Negative one for inactive object.
148   DeviceMemoryAllocator *allocator_;  // Null if this object is inactive.
149 
150   SE_DISALLOW_COPY_AND_ASSIGN(ScopedDeviceMemory);
151 };
152 
153 // Type alias for compatibility with the previous managed memory implementation.
154 using OwningDeviceMemory = ScopedDeviceMemory<uint8>;
155 
156 // Memory allocator interface for the device.
157 //
158 // Intended usage is through Allocate() functions which return an owning smart
159 // pointer.
160 class DeviceMemoryAllocator {
161  public:
162   // Parameter platform indicates which platform the allocator allocates memory
163   // on. Must be non-null.
DeviceMemoryAllocator(const Platform * platform)164   explicit DeviceMemoryAllocator(const Platform *platform)
165       : platform_(platform) {}
~DeviceMemoryAllocator()166   virtual ~DeviceMemoryAllocator() {}
167 
168   // Allocates memory on the device.
169   //
170   // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory
171   // must not be null.  If size == 0, must return a null OwningDeviceMemory.
172   //
173   // 'retry_on_failure': If false, and the first attempt to allocate the memory
174   // fails, the allocation should return immediately without retrying.  An
175   // example use case is optional scratch spaces where a failure has only
176   // performance impact.
177   virtual port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
178                                                       uint64_t size,
179                                                       bool retry_on_failure,
180                                                       int64_t memory_space) = 0;
181 
182   // Two-arg version of Allocate(), which sets retry-on-failure to true and
183   // memory_space to default (0).
184   //
185   // (We don't simply use a default argument on the virtual Allocate function
186   // because default args on virtual functions are disallowed by the Google
187   // style guide.)
Allocate(int device_ordinal,uint64_t size)188   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
189                                               uint64_t size) {
190     return Allocate(device_ordinal, size, /*retry_on_failure=*/true,
191                     /*memory_space=*/0);
192   }
193 
194   // Three-arg version of Allocate(), which sets memory_space to default (0).
Allocate(int device_ordinal,uint64_t size,bool retry_on_failure)195   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64_t size,
196                                               bool retry_on_failure) {
197     return Allocate(device_ordinal, size, retry_on_failure,
198                     /*memory_space=*/0);
199   }
200 
201   // Typed version of the allocation, returning typed memory.
202   template <typename ElemT>
203   port::StatusOr<ScopedDeviceMemory<ElemT>> Allocate(
204       int device_ordinal, uint64_t size, bool retry_on_failure = true,
205       int64_t memory_space = 0) {
206     return Allocate(device_ordinal, size, retry_on_failure, memory_space);
207   }
208 
209   // Must be a nop for null pointers. Should not be used.
210   //
211   // TODO(cheshire): Add deprecation notice.
212   virtual port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0;
213 
214   // Return the platform that the allocator allocates memory on.
platform()215   const Platform *platform() const { return platform_; }
216 
217   // Can we call Deallocate() as soon as a computation has been scheduled on
218   // a stream, or do we have to wait for the computation to complete first?
AllowsAsynchronousDeallocation()219   virtual bool AllowsAsynchronousDeallocation() const { return false; }
220 
221   // Returns a stream pointer on which it is always safe to access memory
222   // allocated by this allocator. It is not necessary to use the returned stream
223   // though, as clients may have additional information letting them safely use
224   // a different stream.
225   virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;
226 
227  protected:
228   const Platform *platform_;
229 };
230 
231 // Default memory allocator for a platform which uses
232 // StreamExecutor::Allocate/Deallocate.
233 class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
234  public:
235   // Create an allocator supporting a single device, corresponding to the passed
236   // executor.
237   explicit StreamExecutorMemoryAllocator(StreamExecutor *executor);
238 
239   // Create an allocator supporting multiple stream executors.
240   //
241   // Precondition: all stream_executors have different device ordinals.
242   StreamExecutorMemoryAllocator(
243       const Platform *platform,
244       absl::Span<StreamExecutor *const> stream_executors);
245 
246   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64_t size,
247                                               bool retry_on_failure,
248                                               int64_t memory_space) override;
249 
250   // Pull in two-arg overload that sets retry_on_failure to true.
251   using DeviceMemoryAllocator::Allocate;
252 
253   port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override;
254 
255   bool AllowsAsynchronousDeallocation() const override;
256 
257   // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
258   // stream executor.
259   port::StatusOr<Stream *> GetStream(int device_ordinal) override;
260 
261   // Gets the stream executor for given device ordinal.
262   port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;
263 
264  private:
265   // Available stream executors. Each stream executor has a different device
266   // ordinal.
267   std::vector<StreamExecutor *> stream_executors_;
268 
269   absl::Mutex mutex_;
270 
271   // Cache of streams for GetStream.
272   std::map<int, Stream> streams_ ABSL_GUARDED_BY(mutex_);
273 };
274 
275 template <typename ElemT>
Free()276 port::Status ScopedDeviceMemory<ElemT>::Free() {
277   if (!wrapped_.is_null()) {
278     CHECK(allocator_ != nullptr) << "Owning pointer in inconsistent state";
279     TF_RETURN_IF_ERROR(allocator_->Deallocate(device_ordinal_, wrapped_));
280   }
281   wrapped_ = DeviceMemory<ElemT>{};
282   return ::tensorflow::OkStatus();
283 }
284 
285 }  // namespace stream_executor
286 
287 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_
288