/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#include "tensorflow/compiler/xrt/xrt_state.h"

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"

namespace tensorflow {
namespace {

// Helper typedef to make ShapeTree ForEach helper lambda signatures more
// readable. They need a type of const T& where in this case T is the
// following pointer.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;

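// Tracks a per-device count of live allocations and their total size in
// bytes. Used only for the VLOG(2) allocation statistics reported by
// XRTBufferAllocation.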
class BufferAllocStats {
 public:
  struct Stats {
    int64_t count = 0;
    int64_t size = 0;
  };

  Stats ReportAlloc(int64_t device, int64_t msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count += 1;
    device_stats->size += msize;
    return *device_stats;
  }

  Stats ReportFree(int64_t device, int64_t msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count -= 1;
    device_stats->size -= msize;
    return *device_stats;
  }

 private:
  mutable mutex lock_;
  std::map<int64_t, Stats> stats_;
};

BufferAllocStats* GetAllocStats() {
  static BufferAllocStats* stats = new BufferAllocStats();
  return stats;
}

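// Allocates fresh device memory for every sub-buffer of the device shape
// derived from `shape`, wraps it in a ScopedShapedBuffer (so partially
// allocated memory is freed if a later allocation fails), and writes the
// tuple index tables to the device.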
Status AllocateScopedShapedBuffer(
    XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
    const xla::Shape& shape, std::unique_ptr<xla::ScopedShapedBuffer>* buffer,
    se::DeviceMemoryAllocator* allocator) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // XLA may use a different representation on device than the representation on
  // the host. XLA does not document any contract for the relationship between
  // these representations :/ Right now, the device shape is always a superset
  // of the host shape, meaning that for any valid ShapeIndex in the host shape
  // that ShapeIndex is also valid in the device shape, but not vice versa. In
  // particular, some host-side types are rewritten to be tuples. We rely on
  // this property when making sub-buffers, because we assume that if the client
  // requests the host-shape sub-buffer at index i, that will correspond to the
  // right device-shape sub-buffer at the same index.
  xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
  VLOG(3) << "Allocating literal buffer: host_shape="
          << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape="
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape);

  // The ScopedShapedBuffer frees the buffers that have so far been allocated if
  // it goes out of scope. That's useful if we return early as the result of an
  // error allocating one of the later buffers.
  *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
      shape, on_device_shape, allocator, device_ordinal);
  for (auto& index_to_buffer : (*buffer)->buffers()) {
    const xla::Shape& subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(
        se::OwningDeviceMemory buffer,
        memory_manager->Allocate(backend, device_ordinal, size, allocator));
    // Move our buffer into shaped_buffer, which takes ownership of it.
    index_to_buffer.second = buffer.Release();
    VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
            << " index " << index_to_buffer.first.ToString() << " (" << size
            << " bytes)";
  }

  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));

  return OkStatus();
}

}  // namespace

XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                                         int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator)
    : allocation_(allocation),
      device_ordinal_(device_ordinal),
      allocator_(allocator) {
  if (VLOG_IS_ON(2)) {
    auto stats =
        GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size());
    LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_
              << " count=" << stats.count << " size=" << stats.size;
  }
}

XRTBufferAllocation::~XRTBufferAllocation() {
  if (VLOG_IS_ON(2)) {
    GetAllocStats()->ReportFree(device_ordinal_, allocation_.size());
  }
  // Deallocate explicitly allows allocation_ to be null.
  TF_CHECK_OK(allocator_->Deallocate(device_ordinal_, allocation_));
  VLOG(2) << "Freed buffer at " << allocation_.opaque() << " ("
          << allocation_.size() << " bytes)";
}

const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
  return allocation_;
}

XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
                                       se::DeviceMemoryAllocator* allocator,
                                       const xla::Shape& on_host_shape,
                                       const xla::Shape& on_device_shape)
    : device_ordinal_(device_ordinal),
      allocator_(allocator),
      on_host_shape_(on_host_shape),
      on_device_shape_(on_device_shape),
      buffers_(&on_device_shape_),
      pin_count_(0) {}

XRTTupleAllocation::~XRTTupleAllocation() { ReleaseBuffers(); }

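// Drops this allocation's reference on every buffer it holds. Buffers that
// are aliased by other live allocations remain valid until their last owner
// unrefs them.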
void XRTTupleAllocation::ReleaseBuffers() {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
      index_buffer.second = nullptr;
    }
  }
}

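// Allocates device memory shaped like `literal` and transfers the literal's
// contents into it, returning the result as a new XRTTupleAllocation that
// owns the buffers.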
/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
    const xla::LiteralBase& literal, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend,
                                                device_ordinal, literal.shape(),
                                                &scoped_buffer, allocator));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      stream.get(), literal, *scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}

/*static*/ Status XRTTupleAllocation::CreateUninitialized(
    const xla::Shape& shape, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend,
                                                device_ordinal, shape,
                                                &scoped_buffer, allocator));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, const xla::Shape& on_host_shape,
    const xla::Shape& on_device_shape, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  *allocation = new XRTTupleAllocation(device_ordinal, allocator, on_host_shape,
                                       on_device_shape);
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  return CreateFromBuffer(shaped_buffer, shaped_buffer.on_host_shape(),
                          shaped_buffer.on_device_shape(), backend,
                          device_ordinal, allocation, allocator);
}

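// Reads the tuple's contents into `literal`. If the allocation is currently
// swapped out, the cached host literal is copied; otherwise the data is
// transferred back from device memory.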
Status XRTTupleAllocation::ToLiteral(xla::Backend* backend,
                                     xla::MutableLiteralBase* literal) {
  mutex_lock lock(lock_);
  return literal_ == nullptr ? StoreToLiteral(backend, literal)
                             : literal->CopyFrom(*literal_);
}

Status XRTTupleAllocation::StoreToLiteral(xla::Backend* backend,
                                          xla::MutableLiteralBase* literal) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  return transfer_manager->TransferLiteralFromDevice(stream.get(),
                                                     shaped_buffer, literal);
}

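// Overwrites the tuple's contents with `literal`, whose shape must match
// on_host_shape(). Writes go to the cached host literal if the allocation is
// swapped out, otherwise they are transferred to the device buffers.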
Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape not matching the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    // The allocation is currently swapped out, and we have a host literal for
    // its content. Just update the host literal with the new value.
    return literal_->CopyFrom(literal);
  }
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
                                                   shaped_buffer);
}

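// Swaps the allocation out to host memory: copies the device buffers into a
// host literal and releases them. Returns false (and does nothing) if the
// allocation is pinned and swap_pinned is false, or if it is already swapped
// out.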
xla::StatusOr<bool> XRTTupleAllocation::SwapOut(xla::Backend* backend,
                                                bool swap_pinned) {
  mutex_lock lock(lock_);
  if (literal_ == nullptr && (!IsPinned() || swap_pinned)) {
    xla::Literal literal(on_host_shape());
    TF_RETURN_IF_ERROR(StoreToLiteral(backend, &literal));
    ReleaseBuffers();
    literal_ = absl::make_unique<xla::Literal>(std::move(literal));
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::SwapIn(
    XRTMemoryManager* memory_manager, xla::Backend* backend,
    se::DeviceMemoryAllocator* allocator) {
  // We need to call AllocateScopedShapedBuffer() outside the locks, since the
  // XRTMemoryManager might end up calling back into the SwapOut() API.
  // So we do a quick check with the IsSwapped() API up front, and it can happen
  // that the allocation becomes swapped in after the check. This means that we
  // will end up doing an allocation, and then releasing it soon after (via its
  // scoped variables). This is an unlikely scenario (two threads calling
  // SwapIn() on the same allocation) though.
  if (!IsSwapped()) {
    return false;
  }

  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(
      AllocateScopedShapedBuffer(memory_manager, backend, device_ordinal(),
                                 on_host_shape(), &scoped_buffer, allocator));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));

  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
        stream.get(), *literal_, *scoped_buffer));

    auto shaped_buffer = scoped_buffer->release();
    InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal());
    literal_ = nullptr;
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::PinAndSwapIn(
    XRTMemoryManager* memory_manager, xla::Backend* backend,
    se::DeviceMemoryAllocator* allocator) {
  Pin();
  return SwapIn(memory_manager, backend, allocator);
}

bool XRTTupleAllocation::IsSwapped() const {
  mutex_lock lock(lock_);
  return literal_ != nullptr;
}

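// Pinning is a simple reference count: while it is non-zero, SwapOut() will
// not swap the allocation out unless swap_pinned is true. Pin() and Unpin()
// return the previous count.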
int64_t XRTTupleAllocation::Pin() { return pin_count_.fetch_add(1); }

int64_t XRTTupleAllocation::Unpin() { return pin_count_.fetch_sub(1); }

bool XRTTupleAllocation::IsPinned() const { return pin_count_ != 0; }

void XRTTupleAllocation::DiscardAllocation(
    const xla::ShapeIndex& buffer_index) {
  buffers_.element(buffer_index)->DiscardAllocation();
}

const xla::Shape& XRTTupleAllocation::on_host_shape() const {
  return on_host_shape_;
}

const xla::Shape& XRTTupleAllocation::on_device_shape() const {
  return on_device_shape_;
}

int XRTTupleAllocation::device_ordinal() const { return device_ordinal_; }

const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() const {
  return buffers_.element({})->allocation();
}

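// Creates a new allocation that refers to the sub-tuple of `parent` at
// `subshape`. If alias_parent_allocation is true the sub-buffers are shared
// (ref-counted) with the parent; otherwise ownership of those buffers is
// moved out of the parent into the new allocation.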
/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
    XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
    XRTTupleAllocation** allocation, bool alias_parent_allocation) {
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* host_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* device_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));

  *allocation =
      new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
                             *host_sub_shape, *device_sub_shape);
  if (alias_parent_allocation) {
    // Copy the subtree of allocations from the parent allocation.
    (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
    // Increment the refcount on each aliased buffer.
    (*allocation)
        ->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
  } else {
    // Find the buffers in the parent allocation that match the subtree, and
    // move the parent allocation's buffer over to the new allocation.
    (*allocation)
        ->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              // Extend the allocation's index to the parent's frame by adding
              // subshape as a prefix.
              xla::ShapeIndex parent_index = subshape;
              for (int i = 0; i < index.size(); ++i) {
                parent_index.push_back(index[i]);
              }
              *buffer = parent->buffers_.element(parent_index);
              *parent->buffers_.mutable_element(parent_index) = nullptr;
            });
  }
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}

void XRTTupleAllocation::SetDeviceMemorySize() {
  size_t size = 0;
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      size += index_buffer.second->allocation().size();
    }
  }
  device_memory_size_ = size;
}

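// Computes the host and device shapes of the tuple that MakeTuple() is about
// to build: the 'spine' comes from the shape of `elements`, and each leaf is
// replaced by the host/device shape of the allocation supplied at that leaf.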
/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
    const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
    se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
    xla::Shape* device_shape) {
  // Initialize both host and device shape to be the 'spine' of the new tuple
  // shape, given by the shape of the tree of tuples.
  *host_shape = elements.shape();
  *device_shape = elements.shape();
  // Now go over the leaves of the tree of tuples, and 'graft' the host/device
  // shapes of the allocation at that leaf onto the expanded host/device shapes
  // at the leaf position.
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (elements.IsLeaf(index)) {
          if (element.allocation == nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a null leaf node at index ",
                index.ToString());
          }
          if (device_ordinal != element.allocation->device_ordinal() ||
              allocator != element.allocation->allocator_) {
            return errors::InvalidArgument(
                "MakeTuple elements must all be allocated on the same device "
                "as the destination.");
          }
          *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
              element.allocation->on_host_shape();
          *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
              element.allocation->on_device_shape();
        } else {
          if (element.allocation != nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a non-null internal node at index ",
                index.ToString());
          }
        }
        return OkStatus();
      }));
  return OkStatus();
}

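// Builds a new tuple allocation whose leaves are existing allocations (the
// leaves of `elements`) and whose internal nodes are freshly allocated index
// tables. Inputs marked release_allocation_after_use donate their buffers to
// the new tuple; all other inputs are aliased with an extra reference.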
/*static*/ Status XRTTupleAllocation::MakeTuple(
    XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
    const xla::ShapeTree<ExpandedTupleInput>& elements,
    XRTTupleAllocation** allocation, se::DeviceMemoryAllocator* allocator) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  xla::Shape host_shape;
  xla::Shape device_shape;
  TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
                                        &host_shape, &device_shape));

  // The aliasing is determined below based on whether or not all the inputs are
  // released while being transferred. allocation_tmp is a local pointer that is
  // copied to *allocation at the end only if the method succeeds.
  XRTTupleAllocation* allocation_tmp = new XRTTupleAllocation(
      device_ordinal, allocator, host_shape, device_shape);
  core::ScopedUnref allocation_unref(allocation_tmp);
  // First allocate device memory for the new tuple index tables, one at each
  // internal node of the elements tree. Do this in a separate pass into a
  // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
  // an allocation fails. Make sure the shape has layout so that the code that
  // writes index tables will be happy lower down.
  xla::Shape spine_shape = elements.shape();
  xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
  auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
      spine_shape, spine_shape, allocator, device_ordinal);
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (!elements.IsLeaf(index)) {
          const xla::Shape& subshape =
              xla::ShapeUtil::GetSubshape(device_shape, index);
          uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
          TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
                              memory_manager->Allocate(backend, device_ordinal,
                                                       size, allocator));
          VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index "
                  << index.ToString();
          // Move the new buffer into new_tuple_buffers, which takes ownership
          // of it.
          new_tuple_buffers->set_buffer(std::move(buffer), index);
        }
        return OkStatus();
      }));
  // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
  // the newly-allocated index tables. Right now there's no owner for the new
  // index tables, so next we will transfer ownership to the new allocation,
  // taking care not to return early on any errors in the meantime.
  xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
  // Now fill in the remaining datastructures. After this ForEachElement
  // completes:
  //   1) Every leaf element of tuple_buffers will be the root buffer of
  //      an existing allocation, and every internal element of tuple_buffers
  //      will be a newly-allocated index table. tuple_buffers does not own any
  //      of these.
  //   2) Every element of allocation_tmp->buffers_ will be a correctly
  //      constructed XRTBufferAllocation wrapping the necessary allocations.
  //      For buffers in existing allocations there will be a new reference
  //      owned by the new allocation, and for newly-allocated index tables
  //      there will be a single reference owned by the new allocation.
  elements.ForEachElement([&](const xla::ShapeIndex& index,
                              const ExpandedTupleInput& element) {
    if (elements.IsLeaf(index)) {
      allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
                                               index);
      tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
      if (element.release_allocation_after_use) {
        // Transfer the references from element's buffers to the new allocation
        // rather than incrementing the refcount. The caller should have
        // validated that release_allocation_after_use is false if
        // element.allocation appears in more than one leaf.
        element.allocation->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex&, XRTBufferAllocationPtr* buffer) {
              *buffer = nullptr;
            });
      } else {
        // Increment the refcount on each newly-aliased buffer.
        element.allocation->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
      }
    } else {
      // This is an internal node of the tuple tree so take ownership of the
      // newly-created index table.
      *allocation_tmp->buffers_.mutable_element(index) =
          new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
                                  allocator);
    }
  });
  allocation_tmp->SetDeviceMemorySize();
  // Because the internal nodes of tuple_buffers are exactly the new index
  // tables, WriteTupleIndexTables will write only the new index tables and not
  // rewrite the index tables for the existing allocations.
  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));

  *allocation = allocation_tmp;
  // Get another reference since allocation_tmp will be Unrefed automatically on
  // exit.
  (*allocation)->Ref();
  return OkStatus();
}

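// True iff every buffer held by this allocation has a reference count of one,
// i.e. no other live allocation aliases any part of it.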
bool XRTTupleAllocation::IsExclusiveOwner() const {
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr &&
        !index_buffer.second->RefCountIsOne()) {
      return false;
    }
  }
  return true;
}

size_t XRTTupleAllocation::GetDeviceMemorySize() const {
  return device_memory_size_;
}

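// Points this allocation at the buffers of `shaped_buffer`, wrapping each one
// in a new ref-counted XRTBufferAllocation and dropping any references that
// were previously held.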
void XRTTupleAllocation::InitializeFromShapedBuffer(
    const xla::ShapedBuffer& shaped_buffer,
    se::DeviceMemoryAllocator* allocator, int device_ordinal) {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
    }
    // Make a reference-counted version of the allocated buffer.
    index_buffer.second = new XRTBufferAllocation(
        shaped_buffer.buffer(index_buffer.first), device_ordinal, allocator);
  }
}

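// Builds a non-owning ShapedBuffer view over the allocation's buffers, so it
// can be handed to the TransferManager. Fails if any non-empty sub-buffer has
// already been released.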
xla::StatusOr<xla::ShapedBuffer> XRTTupleAllocation::ToShapedBuffer() {
  xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
                                  device_ordinal_);
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    shaped_buffer.set_buffer(index_buffer.second->allocation(),
                             index_buffer.first);
  }
  return std::move(shaped_buffer);
}

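// Makes the buffer at dest_index alias the source buffer at source_index,
// taking a new reference on the source buffer and releasing whatever was
// previously held at dest_index.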
Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
                                           const xla::ShapeIndex& source_index,
                                           const xla::ShapeIndex& dest_index) {
  XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
  XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
  if (dest_buffer != nullptr) {
    // We allow the destination size to be zero, because there are cases where
    // we come in later to fill in null/uninitialized device buffers. In all
    // other cases, the size of the new buffer must match.
    if (source_buffer->allocation().size() !=
            dest_buffer->allocation().size() &&
        dest_buffer->allocation().size() != 0) {
      return errors::InvalidArgument(
          "Source buffer at index ", source_index.ToString(),
          " does not match the size of destination buffer at index ",
          dest_index.ToString(), ": ", source_buffer->allocation().size(),
          " vs ", dest_buffer->allocation().size());
    }
  } else {
    const xla::Shape& source_subshape =
        xla::ShapeUtil::GetSubshape(source.on_device_shape(), source_index);
    const xla::Shape& dest_subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape(), dest_index);
    if (!xla::ShapeUtil::Equal(source_subshape, dest_subshape)) {
      return errors::InvalidArgument(
          "Source and destination subshapes do not match: source=",
          xla::ShapeUtil::HumanStringWithLayout(source_subshape),
          " dest=", xla::ShapeUtil::HumanStringWithLayout(dest_subshape));
    }
  }
  *buffers_.mutable_element(dest_index) = source_buffer;
  source_buffer->Ref();
  if (dest_buffer != nullptr) {
    // If we handed over the ownership of a buffer in ToExecutionInput(), we
    // will be called here on the way back from execution, to alias back the
    // buffer at that index. In that case the buffers will be the same. So we
    // need to discard the memory at the destination buffer, before releasing
    // the reference.
    if (dest_buffer->allocation().IsSameAs(source_buffer->allocation()) &&
        dest_buffer != source_buffer) {
      dest_buffer->DiscardAllocation();
    }
    dest_buffer->Unref();
  }
  return OkStatus();
}

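// Builds the xla::ExecutionInput used to feed this allocation into an XLA
// execution. Buffers approved by alias_checker are wrapped as (maybe-)owning
// device memory so XLA may alias or donate them, while this allocation keeps
// ownership of the underlying memory; all other buffers are passed as plain
// unowned device memory. Fails if any buffer has already been released.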
xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
        alias_checker) {
  xla::ExecutionInput result(on_device_shape(), on_host_shape());
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    TF_ASSIGN_OR_RETURN(bool should_alias, alias_checker(index_buffer.first));
    if (!should_alias) {
      result.SetBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(index_buffer.second->allocation()));
    } else {
      // We keep the ownership of the device memory here.
      result.SetUnownedBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(se::OwningDeviceMemory(
              index_buffer.second->allocation(), device_ordinal_, allocator_)));
    }
  }
  return std::move(result);
}

}  // namespace tensorflow