/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for keeping track of on-device state.

#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_

#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

// Cannot include xrt_memory_manager.h here, as it needs to include this file.
class XRTMemoryManager;

// TODO(misard) make this a Tensor if and when that makes sense.
// A reference-counted wrapper around a buffer allocation. This maps an XLA
// tuple index or a non-tuple XLA shape to a region of device memory. The
// device memory buffer is freed when the reference count drops to zero.
class XRTBufferAllocation : public core::RefCounted {
 public:
  XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                      int device_ordinal, se::DeviceMemoryAllocator* allocator);
  ~XRTBufferAllocation() override;

  // The region of device memory being wrapped.
  const se::DeviceMemoryBase& allocation();

  // Stops tracking the wrapped device memory without freeing it: the handle
  // now wraps an empty DeviceMemoryBase, so the destructor deallocates
  // nothing. Used when ownership of the memory has moved elsewhere.
  void DiscardAllocation() { allocation_ = se::DeviceMemoryBase(); }

 private:
  se::DeviceMemoryBase allocation_;
  int device_ordinal_;
  se::DeviceMemoryAllocator* allocator_;
};
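// A minimal lifetime sketch for XRTBufferAllocation (illustrative only; the
// `mem` and `allocator` values are assumed to come from a prior allocation):
//
//   auto* buffer = new XRTBufferAllocation(mem, /*device_ordinal=*/0,
//                                          allocator);
//   buffer->Ref();    // Share ownership with a second handle.
//   buffer->Unref();  // Drop the extra reference; the memory stays alive.
//   buffer->Unref();  // Count reaches zero; the device memory is freed,
//                     // unless DiscardAllocation() was called first.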
// An XRTTupleAllocation represents an allocated memory area on the device.
// New tuples can be created in three ways: by passing a literal, in which case
// device memory is allocated and the literal is transferred to that memory; by
// aliasing a sub-shape of an existing tuple-shaped handle; or by aliasing a
// vector of existing handles to create a new tuple. The underlying storage is
// reference-counted. When a handle is released, the reference count of each
// storage buffer is decremented, and buffers with no outstanding references
// are freed.
class XRTTupleAllocation : public core::RefCounted {
 public:
  ~XRTTupleAllocation() override;

  // Allocates new device memory buffers sufficient to store literal, transfers
  // literal to that memory, and returns an XRTTupleAllocation handle to the
  // allocated buffers.
  static Status CreateAndTransfer(const xla::LiteralBase& literal,
                                  XRTMemoryManager* memory_manager,
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation,
                                  se::DeviceMemoryAllocator* allocator);

  // Allocates new device memory buffers sufficient to store a tensor of
  // the specified shape, and returns an XRTTupleAllocation handle to the
  // allocated buffers. The allocated buffers are not initialized.
  static Status CreateUninitialized(const xla::Shape& shape,
                                    XRTMemoryManager* memory_manager,
                                    xla::Backend* backend, int device_ordinal,
                                    XRTTupleAllocation** allocation,
                                    se::DeviceMemoryAllocator* allocator);

  // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation,
                                 se::DeviceMemoryAllocator* allocator);

  // Same as the CreateFromBuffer() API above, but with the shapes being passed
  // as input. This API is used when creating tuple allocations with the output
  // of XLA computations which emit dynamically shaped output via the output
  // shape table.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 const xla::Shape& on_host_shape,
                                 const xla::Shape& on_device_shape,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation,
                                 se::DeviceMemoryAllocator* allocator);

  // Aliases a sub-shape of parent and returns an XRTTupleAllocation handle
  // to the sub-shape. If alias_parent_allocation is true, the buffers in the
  // sub-shape will be shared between parent and the returned allocation;
  // otherwise the overlapping buffers in parent will be replaced by nullptr.
  static Status MakeSubBuffer(XRTTupleAllocation* parent,
                              const xla::ShapeIndex& subshape,
                              XRTTupleAllocation** allocation,
                              bool alias_parent_allocation);
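  // A minimal usage sketch for the factory methods above (illustrative only;
  // `literal`, `memory_manager`, `backend` and `allocator` are assumed to be
  // provided by the surrounding XRT op context):
  //
  //   XRTTupleAllocation* tuple = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
  //       literal, memory_manager, backend, /*device_ordinal=*/0, &tuple,
  //       allocator));
  //   // Alias element {0} of the tuple without copying device memory.
  //   XRTTupleAllocation* element = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
  //       tuple, /*subshape=*/{0}, &element,
  //       /*alias_parent_allocation=*/true));
  //   element->Unref();
  //   tuple->Unref();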
  // A structure describing a leaf of a tree of tuples to expand. Each leaf
  // contains an allocation and indicates whether or not the allocation's
  // handle should be freed after incorporating its buffers into the expanded
  // tree.
  struct ExpandedTupleInput {
    RefPtr<XRTTupleAllocation> allocation;
    bool release_allocation_after_use;
  };

  // Returns a handle to a new tuple where the subtree of the new tuple at an
  // index corresponding to a leaf of 'elements' is constructed from the
  // allocation (i.e., a tuple or array) pointed to by that leaf. If
  // release_allocation_after_use is false at a leaf, the new tuple will alias
  // the input allocation at that leaf, otherwise the input allocation will be
  // released. Input allocations may be repeated (appear in more than one leaf)
  // in which case the corresponding buffers in the output tuple will alias. If
  // an input is repeated, release_allocation_after_use must be false for every
  // leaf where that input appears. The latter property is not validated by
  // MakeTuple and must be enforced by the caller.
  static Status MakeTuple(XRTMemoryManager* memory_manager,
                          xla::Backend* backend, int device_ordinal,
                          const xla::ShapeTree<ExpandedTupleInput>& elements,
                          XRTTupleAllocation** allocation,
                          se::DeviceMemoryAllocator* allocator);

  // Copies the allocation from device to host and returns it in literal.
  Status ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal);

  // Writes a new literal value to the allocation.
  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);

  // Stores the content of the tuple allocation into the internal literal, and
  // releases all the device buffers. The swap_pinned flag tells whether a
  // pinned allocation should be swapped out. It should be false in all cases
  // except during the memory compaction operation from the XRTMemoryManager.
  // Returns a boolean telling whether the allocation was swapped out.
  xla::StatusOr<bool> SwapOut(xla::Backend* backend, bool swap_pinned);

  // Allocates the device memory required to store the tuple value held within
  // the internal literal, and transfers the literal value into the device
  // memory. Returns a boolean telling whether the allocation was swapped in.
  xla::StatusOr<bool> SwapIn(XRTMemoryManager* memory_manager,
                             xla::Backend* backend,
                             se::DeviceMemoryAllocator* allocator);

  // Pins the allocation first, then swaps it in (if it is not already). After
  // this API returns, the allocation is pinned and its content is in device
  // memory. The caller is responsible for releasing the pin-count using the
  // Unpin() API.
  xla::StatusOr<bool> PinAndSwapIn(XRTMemoryManager* memory_manager,
                                   xla::Backend* backend,
                                   se::DeviceMemoryAllocator* allocator);

  // Checks whether the allocation is currently swapped out.
  bool IsSwapped() const;

  // Increases the pin-count of this allocation. If the pin-count is greater
  // than 0, the allocation cannot be swapped. Returns the pin-count value
  // before the increase.
  int64_t Pin();

  // Decreases the pin-count of this allocation. Returns the pin-count value
  // before the decrease.
  int64_t Unpin();

  // Checks whether the allocation is currently pinned.
  bool IsPinned() const;

  // True if none of the buffers in the allocation are aliased by any other
  // live handle.
  bool IsExclusiveOwner() const;

  // Retrieves the footprint, in terms of device memory, of this allocation.
  size_t GetDeviceMemorySize() const;

  // The ordinal of the device holding this tuple.
  int device_ordinal() const;

  // Returns the shape of the tuple as seen by the host.
  const xla::Shape& on_host_shape() const;

  // Returns the shape of the tuple as stored on the device.
  const xla::Shape& on_device_shape() const;

  // Returns the buffer pointed to by the root of the tuple.
  const se::DeviceMemoryBase& root_allocation() const;

  // Stops managing the storage for the allocation at buffer_index, e.g.,
  // because it has been aliased to the output buffer of a computation.
  void DiscardAllocation(const xla::ShapeIndex& buffer_index);

  // Returns the tree of allocations as a ShapedBuffer. This tree may not have
  // the same shape as on_host_shape.
  xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer();
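  // A sketch of the swap/pin lifecycle (illustrative only; the context
  // objects are assumptions):
  //
  //   // Ensure the buffers are resident on device and cannot be swapped out
  //   // while in use.
  //   TF_ASSIGN_OR_RETURN(
  //       bool swapped_in,
  //       tuple->PinAndSwapIn(memory_manager, backend, allocator));
  //   (void)swapped_in;
  //   // ... run the computation against tuple's buffers ...
  //   tuple->Unpin();  // Allow the allocation to be swapped out again.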
  // Aliases the source buffer at source_index into the current tuple
  // allocation at dest_index.
  Status AliasBufferFrom(const XRTTupleAllocation& source,
                         const xla::ShapeIndex& source_index,
                         const xla::ShapeIndex& dest_index);

  // Returns the device memory tree of this allocation. If the alias_checker
  // function returns true for a given index, owned device memory is returned
  // to the caller. But the tuple allocation cannot release the ownership in
  // full, as the execute operation might fail. So we rely on a call to
  // AliasBufferFrom() to re-alias the buffers back. This is not great (to say
  // the least), but the current aliasing logic relies on
  // MaybeOwningDeviceMemory being owned to detect the fact that the user may
  // want to alias a buffer. Unfortunately, to do that it needs to release the
  // ownership, which is a problem if the execute fails.
  // This calls for a refactoring of the whole owning/maybe-owning interface to
  // introduce a sharing concept (IOW, a shared_ptr model vs. a unique_ptr
  // one). We'd need something similar to XRTTupleAllocation instead of
  // ScopedShapedBuffer, which wants ownership and does not allow sharing.
  xla::StatusOr<xla::ExecutionInput> ToExecutionInput(
      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
          alias_checker);
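  // A hedged sketch of the intended call pattern (the alias predicate and
  // `donated_indices` below are assumptions, not part of this API):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       xla::ExecutionInput input,
  //       tuple->ToExecutionInput([&](const xla::ShapeIndex& index) {
  //         return xla::StatusOr<bool>(donated_indices.contains(index));
  //       }));
  //   // If the execution fails, re-alias the donated buffers back into the
  //   // tuple via AliasBufferFrom() so the allocation stays consistent.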
 private:
  // Creates a new handle with (tuple) shape.
  XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator,
                     const xla::Shape& on_host_shape,
                     const xla::Shape& on_device_shape);

  // Inherits the allocations represented in shaped_buffer, which must have the
  // same shape as buffers_.
  void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
                                  se::DeviceMemoryAllocator* allocator,
                                  int device_ordinal);

  // Releases all the XRTBufferAllocation buffer references and sets the
  // corresponding shape tree entries to nullptr.
  void ReleaseBuffers();

  // Stores the content of the allocation from device memory into the target
  // host literal.
  Status StoreToLiteral(xla::Backend* backend,
                        xla::MutableLiteralBase* literal);

  // Sets the total size of the buffers held within this allocation. This API
  // should be called once, when an XRTTupleAllocation object is created, as
  // the XRTTupleAllocation shapes never change, and hence neither does the
  // device memory size.
  void SetDeviceMemorySize();

  // Takes a tree 'elements' where each leaf is an allocation, validates that
  // they are all on the device_ordinal managed by allocator, and returns in
  // host_shape and device_shape the host/device shapes of the expanded tree,
  // where at each leaf of elements the shape of the allocation at that leaf is
  // grafted on.
  static Status ExpandTreeOfTuples(
      const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
      se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
      xla::Shape* device_shape);

  // The lock which protects the internal operations of the tuple allocation.
  // It is mutable to allow const-like operations to be declared as such.
  mutable mutex lock_;

  // Location of the memory that is being managed.
  const int device_ordinal_;
  se::DeviceMemoryAllocator* const allocator_;

  // The shape that the caller thinks the tuple has.
  const xla::Shape on_host_shape_;
  // The shape that the tuple has on device. Stored explicitly instead of
  // using the shape held in the ShapeTree, because ShapeTree discards the
  // layout.
  const xla::Shape on_device_shape_;
  // The tree of reference-counted buffers, which uses on_device_shape_ as its
  // shape.
  xla::ShapeTree<XRTBufferAllocation*> buffers_;
  // The footprint of the allocation when resident in device memory.
  size_t device_memory_size_ = 0;
  // If the allocation is swapped out, this is the literal storing its content.
  std::unique_ptr<xla::Literal> literal_;
  // A pinned allocation is one which cannot be swapped out. If pin_count_ > 0
  // then the allocation is pinned.
  std::atomic<int64_t> pin_count_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_