/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for keeping track of on-device state.

#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_

#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

// Cannot include xrt_memory_manager.h here, as it needs to include this file.
class XRTMemoryManager;

// TODO(misard) make this a Tensor if and when that makes sense.
// A reference-counted wrapper around a buffer allocation. This maps an XLA
// tuple index or a non-tuple XLA shape to a region of device memory. The device
// memory buffer is freed when the reference count drops to zero.
class XRTBufferAllocation : public core::RefCounted {
 public:
  XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                      int device_ordinal, se::DeviceMemoryAllocator* allocator);
  ~XRTBufferAllocation() override;

  // The region of device memory being wrapped.
  const se::DeviceMemoryBase& allocation();

  // Stops managing the wrapped device memory region, so it is not deallocated
  // when the reference count drops to zero.
  void DiscardAllocation() { allocation_ = se::DeviceMemoryBase(); }

 private:
  se::DeviceMemoryBase allocation_;
  int device_ordinal_;
  se::DeviceMemoryAllocator* allocator_;
};

// An XRTTupleAllocation represents an allocated memory area on the device.
// New tuples can be created in three ways: by passing a literal, in which case
// device memory is allocated and the literal is transferred to that memory; by
// aliasing a sub-shape of an existing tuple-shaped handle; or by aliasing a
// vector of existing handles to create a new tuple. The underlying storage is
// reference-counted. When a handle is released, the reference count of each
// storage buffer is decremented, and buffers with no outstanding references are
// freed.
class XRTTupleAllocation : public core::RefCounted {
 public:
  ~XRTTupleAllocation() override;

  // Allocates new device memory buffers sufficient to store literal, transfers
  // literal to that memory, and returns an XRTTupleAllocation handle to the
  // allocated buffers.
  static Status CreateAndTransfer(const xla::LiteralBase& literal,
                                  XRTMemoryManager* memory_manager,
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation,
                                  se::DeviceMemoryAllocator* allocator);
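  //
  // Usage sketch (illustrative only): `memory_manager`, `backend`, and
  // `allocator` are assumed to come from the caller's device context, and
  // xla::LiteralUtil comes from "tensorflow/compiler/xla/literal_util.h".
  //
  //   xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   XRTTupleAllocation* allocation = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
  //       literal, memory_manager, backend, /*device_ordinal=*/0, &allocation,
  //       allocator));
  //   // ... use the allocation ...
  //   allocation->Unref();  // Drops the handle's reference.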

  // Allocates new device memory buffers sufficient to store a tensor of
  // the specified shape, and returns an XRTTupleAllocation handle to the
  // allocated buffers.  The allocated buffers are not initialized.
  static Status CreateUninitialized(const xla::Shape& shape,
                                    XRTMemoryManager* memory_manager,
                                    xla::Backend* backend, int device_ordinal,
                                    XRTTupleAllocation** allocation,
                                    se::DeviceMemoryAllocator* allocator);

  // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation,
                                 se::DeviceMemoryAllocator* allocator);
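  //
  // Usage sketch (illustrative only): `result` is assumed to be an
  // xla::ScopedShapedBuffer returned by an XLA execution; release() hands its
  // buffers over to the new handle.
  //
  //   XRTTupleAllocation* output = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
  //       result.release(), backend, device_ordinal, &output, allocator));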

  // Same as the CreateFromBuffer() API above, but with the shapes being passed
  // as input. This API is used when creating tuple allocations with the output
  // of XLA computations which emit dynamically shaped output via the output
  // shape table.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 const xla::Shape& on_host_shape,
                                 const xla::Shape& on_device_shape,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation,
                                 se::DeviceMemoryAllocator* allocator);

  // Aliases a sub-shape of parent and returns an XRTTupleAllocation handle
  // to the sub-shape. If alias_parent_allocation is true, the buffers in the
  // sub-shape will be shared between parent and the returned allocation;
  // otherwise the overlapping buffers in parent will be replaced by nullptr.
  static Status MakeSubBuffer(XRTTupleAllocation* parent,
                              const xla::ShapeIndex& subshape,
                              XRTTupleAllocation** allocation,
                              bool alias_parent_allocation);
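  //
  // Usage sketch (illustrative only): alias element {0} of an existing
  // tuple-shaped `parent` allocation.
  //
  //   XRTTupleAllocation* sub = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
  //       parent, /*subshape=*/{0}, &sub, /*alias_parent_allocation=*/true));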

  // A structure describing a leaf of a tree of tuples to expand. Each leaf
  // contains an allocation and indicates whether or not the allocation's handle
  // should be freed after incorporating its buffers into the expanded tree.
  struct ExpandedTupleInput {
    RefPtr<XRTTupleAllocation> allocation;
    bool release_allocation_after_use;
  };

  // Returns a handle to a new tuple where the subtree of the new tuple at an
  // index corresponding to a leaf of 'elements' is constructed from the
  // allocation (i.e., a tuple or array) pointed to by that leaf. If
  // release_allocation_after_use is false at a leaf, the new tuple will alias
  // the input allocation at that leaf, otherwise the input allocation will be
  // released. Input allocations may be repeated (appear in more than one leaf)
  // in which case the corresponding buffers in the output tuple will alias. If
  // an input is repeated, release_allocation_after_use must be false for every
  // leaf where that input appears. The latter property is not validated by
  // MakeTuple and must be enforced by the caller.
  static Status MakeTuple(XRTMemoryManager* memory_manager,
                          xla::Backend* backend, int device_ordinal,
                          const xla::ShapeTree<ExpandedTupleInput>& elements,
                          XRTTupleAllocation** allocation,
                          se::DeviceMemoryAllocator* allocator);
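  //
  // Usage sketch (illustrative only; RefPtr reference-count bookkeeping for
  // `a` and `b` is elided): build the tuple (a, b) from two existing handles
  // on the same device.
  //
  //   xla::ShapeTree<ExpandedTupleInput> elements(
  //       xla::ShapeUtil::MakeTupleShape(
  //           {a->on_host_shape(), b->on_host_shape()}));
  //   elements.mutable_element({0})->allocation = RefPtr<XRTTupleAllocation>(a);
  //   elements.mutable_element({0})->release_allocation_after_use = false;
  //   elements.mutable_element({1})->allocation = RefPtr<XRTTupleAllocation>(b);
  //   elements.mutable_element({1})->release_allocation_after_use = false;
  //   XRTTupleAllocation* tuple = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeTuple(
  //       memory_manager, backend, device_ordinal, elements, &tuple, allocator));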

  // Copies the allocation from device to host and returns it in literal.
  Status ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal);
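  //
  // Usage sketch (illustrative only): read the allocation back to the host.
  //
  //   xla::Literal result(allocation->on_host_shape());
  //   TF_RETURN_IF_ERROR(allocation->ToLiteral(backend, &result));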

  // Writes a new literal value to the allocation.
  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);

  // Stores the content of the tuple allocation into the internal literal, and
  // releases all the device buffers. The swap_pinned flag tells whether a
  // pinned allocation should be swapped out. It should be false in all cases
  // except during the memory compaction operation from the XRTMemoryManager.
  // Returns a boolean telling whether the allocation was swapped out.
  xla::StatusOr<bool> SwapOut(xla::Backend* backend, bool swap_pinned);

  // Allocates the device memory required to store the tuple value held within
  // the internal literal, and transfers the literal value into the device
  // memory. Returns a boolean telling whether the allocation was swapped in.
  xla::StatusOr<bool> SwapIn(XRTMemoryManager* memory_manager,
                             xla::Backend* backend,
                             se::DeviceMemoryAllocator* allocator);

  // Pins the allocation first, then swaps it in (if it is not already swapped
  // in). After this API returns, the allocation is pinned and its content is in
  // device memory. The caller is responsible for releasing the pin-count using
  // the Unpin() API.
  xla::StatusOr<bool> PinAndSwapIn(XRTMemoryManager* memory_manager,
                                   xla::Backend* backend,
                                   se::DeviceMemoryAllocator* allocator);
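  //
  // Usage sketch (illustrative only): pin the allocation while its device
  // buffers are in use, then release the pin.
  //
  //   TF_RETURN_IF_ERROR(
  //       allocation->PinAndSwapIn(memory_manager, backend, allocator).status());
  //   // ... use the device buffers ...
  //   allocation->Unpin();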

  // Checks whether the allocation is currently swapped out.
  bool IsSwapped() const;

  // Increases the pin-count of this allocation. If the pin-count is greater
  // than 0, the allocation cannot be swapped. Returns the pin-count value
  // before the increase.
  int64_t Pin();

  // Decreases the pin-count of this allocation. Returns the pin-count value
  // before the decrease.
  int64_t Unpin();

  // Checks whether the allocation is currently pinned.
  bool IsPinned() const;

  // True if none of the buffers in the allocation are aliased by any other live
  // handle.
  bool IsExclusiveOwner() const;

  // Retrieves the device memory footprint of this allocation.
  size_t GetDeviceMemorySize() const;

  // The ordinal of the device holding this tuple.
  int device_ordinal() const;

  // Returns the shape of the tuple as seen by the host.
  const xla::Shape& on_host_shape() const;

  // Returns the shape of the tuple as stored on the device.
  const xla::Shape& on_device_shape() const;

  // Returns the buffer pointed to by the root of the tuple.
  const se::DeviceMemoryBase& root_allocation() const;

  // Stops managing the storage for the allocation at buffer_index, e.g.,
  // because it has been aliased to the output buffer of a computation.
  void DiscardAllocation(const xla::ShapeIndex& buffer_index);

  // Returns the tree of allocations as a ShapedBuffer. This tree may not have
  // the same shape as on_host_shape.
  xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer();

  // Aliases the source buffer at source_index into the current tuple allocation
  // at dest_index.
  Status AliasBufferFrom(const XRTTupleAllocation& source,
                         const xla::ShapeIndex& source_index,
                         const xla::ShapeIndex& dest_index);
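  //
  // Usage sketch (illustrative only): alias element {1} of `source` into
  // element {0} of this allocation, assuming compatible sub-shapes.
  //
  //   TF_RETURN_IF_ERROR(allocation->AliasBufferFrom(
  //       *source, /*source_index=*/{1}, /*dest_index=*/{0}));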

  // Returns the device memory tree of this allocation. If the alias_checker
  // function returns true for a given index, owned device memory is returned
  // to the caller. But the tuple allocation cannot release the ownership in
  // full, as the execute operation might fail. So we rely on a call to
  // AliasBufferFrom() to re-alias back the buffers. This is not great (to say
  // the least), but the current aliasing logic relies on
  // MaybeOwningDeviceMemory being owned, to detect the fact that the user may
  // want to alias a buffer. Unfortunately, to do that, it needs to release the
  // ownership, which is a problem if the execute fails.
  // This calls for a refactoring of the whole owning/maybe-owning interface to
  // introduce a sharing concept (IOW, a shared_ptr model vs. unique_ptr).
  // We'd need something similar to XRTTupleAllocation instead of
  // ScopedShapedBuffer, which wants ownership and does not allow sharing.
  xla::StatusOr<xla::ExecutionInput> ToExecutionInput(
      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
          alias_checker);
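  //
  // Usage sketch (illustrative only): build an ExecutionInput that never
  // donates buffers, by passing an alias_checker that always answers false.
  //
  //   TF_ASSIGN_OR_RETURN(
  //       xla::ExecutionInput input,
  //       allocation->ToExecutionInput(
  //           [](const xla::ShapeIndex&) -> xla::StatusOr<bool> {
  //             return false;
  //           }));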

 private:
  // Creates a new handle with (tuple) shape.
  XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator,
                     const xla::Shape& on_host_shape,
                     const xla::Shape& on_device_shape);

  // Inherits the allocations represented in shaped_buffer, which must have the
  // same shape as buffers_.
  void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
                                  se::DeviceMemoryAllocator* allocator,
                                  int device_ordinal);

  // Releases all the XRTBufferAllocation buffer references and sets the
  // corresponding shape tree entries to nullptr.
  void ReleaseBuffers();

  // Stores the content of the allocation from device memory to the target host
  // literal.
  Status StoreToLiteral(xla::Backend* backend,
                        xla::MutableLiteralBase* literal);

  // Sets the total size of the buffers held within this allocation. This API
  // should be called once, when an XRTTupleAllocation object is created, as
  // the XRTTupleAllocation shapes never change, and hence neither does the
  // device memory size.
  void SetDeviceMemorySize();

  // Takes a tree 'elements' where each leaf is an allocation, validates that
  // they are all on the device with device_ordinal managed by allocator, and
  // returns in host_shape and device_shape the host/device shapes of the
  // expanded tree, where the shape of the allocation at each leaf of elements
  // is grafted on.
  static Status ExpandTreeOfTuples(
      const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
      se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
      xla::Shape* device_shape);

  // The lock which protects the internal operations of the tuple allocation. It
  // is mutable so that logically const operations can be declared as such.
  mutable mutex lock_;

  // Location of the memory that is being managed.
  const int device_ordinal_;
  se::DeviceMemoryAllocator* const allocator_;

  // The shape that the caller thinks the tuple has.
  const xla::Shape on_host_shape_;
  // The shape that the tuple has on device. Store this explicitly instead of
  // using a shape stored in ShapeTree because ShapeTree discards the layout.
  const xla::Shape on_device_shape_;
  // The tree of reference-counted buffers, which uses on_device_shape_ as its
  // shape.
  xla::ShapeTree<XRTBufferAllocation*> buffers_;
  // The footprint of the allocation, when residing on device memory.
  size_t device_memory_size_ = 0;
  // If the allocation is swapped out, this is the literal storing its content.
  std::unique_ptr<xla::Literal> literal_;
  // A pinned allocation is one which cannot be swapped out. If pin_count_ > 0
  // then the allocation is pinned.
  std::atomic<int64_t> pin_count_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_