/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#include "tensorflow/compiler/xrt/xrt_state.h"

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"

namespace tensorflow {
namespace {

// Helper typedef to make the ShapeTree ForEach* lambda signatures more
// readable. Those lambdas take a parameter of type const T&, where in this
// case T is the pointer type below.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;

class BufferAllocStats {
 public:
  struct Stats {
    int64_t count = 0;
    int64_t size = 0;
  };

  Stats ReportAlloc(int64_t device, int64_t msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count += 1;
    device_stats->size += msize;
    return *device_stats;
  }

  Stats ReportFree(int64_t device, int64_t msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count -= 1;
    device_stats->size -= msize;
    return *device_stats;
  }

 private:
  mutable mutex lock_;
  std::map<int64_t, Stats> stats_;
};

BufferAllocStats* GetAllocStats() {
  static BufferAllocStats* stats = new BufferAllocStats();
  return stats;
}

Status AllocateScopedShapedBuffer(
    XRTMemoryManager* memory_manager, xla::Backend* backend,
    int device_ordinal, const xla::Shape& shape,
    std::unique_ptr<xla::ScopedShapedBuffer>* buffer,
    se::DeviceMemoryAllocator* allocator) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // XLA may use a different representation on device than on the host, and it
  // does not document any contract for the relationship between the two
  // representations. Right now the device shape is always a superset of the
  // host shape: any ShapeIndex that is valid in the host shape is also valid
  // in the device shape, but not vice versa. In particular, some host-side
  // types are rewritten to be tuples. We rely on this property when making
  // sub-buffers, because we assume that if the client requests the host-shape
  // sub-buffer at index i, it corresponds to the device-shape sub-buffer at
  // the same index.
  xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
  VLOG(3) << "Allocating literal buffer: host_shape="
          << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape="
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape);

  // The ScopedShapedBuffer frees the buffers that have so far been allocated
  // if it goes out of scope. That's useful if we return early as the result
  // of an error allocating one of the later buffers.
  *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
      shape, on_device_shape, allocator, device_ordinal);
  for (auto& index_to_buffer : (*buffer)->buffers()) {
    const xla::Shape& subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(
        se::OwningDeviceMemory buffer,
        memory_manager->Allocate(backend, device_ordinal, size, allocator));
    // Move our buffer into the ScopedShapedBuffer, which takes ownership of
    // it.
    index_to_buffer.second = buffer.Release();
    VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
            << " index " << index_to_buffer.first.ToString() << " (" << size
            << " bytes)";
  }

  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));

  return OkStatus();
}
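
// For reference, the XRTTupleAllocation::Create*() methods below use this
// helper along these lines (a condensed sketch; error handling is as in the
// real callers, and `memory_manager`, `backend` and `allocator` come from
// the surrounding code):
//
//   std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
//   TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
//       memory_manager, backend, device_ordinal, literal.shape(),
//       &scoped_buffer, allocator));
//   // Every sub-buffer is now backed by device memory and the tuple index
//   // tables have been written.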

}  // namespace

XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                                         int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator)
    : allocation_(allocation),
      device_ordinal_(device_ordinal),
      allocator_(allocator) {
  if (VLOG_IS_ON(2)) {
    auto stats =
        GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size());
    LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_
              << " count=" << stats.count << " size=" << stats.size;
  }
}

XRTBufferAllocation::~XRTBufferAllocation() {
  if (VLOG_IS_ON(2)) {
    GetAllocStats()->ReportFree(device_ordinal_, allocation_.size());
  }
  // Deallocate explicitly allows allocation_ to be null.
  TF_CHECK_OK(allocator_->Deallocate(device_ordinal_, allocation_));
  VLOG(2) << "Freed buffer at " << allocation_.opaque() << " ("
          << allocation_.size() << " bytes)";
}

const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
  return allocation_;
}

XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
                                       se::DeviceMemoryAllocator* allocator,
                                       const xla::Shape& on_host_shape,
                                       const xla::Shape& on_device_shape)
    : device_ordinal_(device_ordinal),
      allocator_(allocator),
      on_host_shape_(on_host_shape),
      on_device_shape_(on_device_shape),
      buffers_(&on_device_shape_),
      pin_count_(0) {}

XRTTupleAllocation::~XRTTupleAllocation() { ReleaseBuffers(); }

void XRTTupleAllocation::ReleaseBuffers() {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
      index_buffer.second = nullptr;
    }
  }
}

/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
    const xla::LiteralBase& literal, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend,
                                                device_ordinal, literal.shape(),
                                                &scoped_buffer, allocator));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      stream.get(), literal, *scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}
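
// Example usage (a minimal sketch; `memory_manager`, `backend` and
// `allocator` are assumed to come from the surrounding XRT op context):
//
//   xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
//   XRTTupleAllocation* allocation = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
//       literal, memory_manager, backend, /*device_ordinal=*/0, &allocation,
//       allocator));
//   // The caller owns a reference to `allocation` and must Unref() it when
//   // done.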

/*static*/ Status XRTTupleAllocation::CreateUninitialized(
    const xla::Shape& shape, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend,
                                                device_ordinal, shape,
                                                &scoped_buffer, allocator));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, const xla::Shape& on_host_shape,
    const xla::Shape& on_device_shape, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       on_host_shape, on_device_shape);
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation,
    se::DeviceMemoryAllocator* allocator) {
  return CreateFromBuffer(shaped_buffer, shaped_buffer.on_host_shape(),
                          shaped_buffer.on_device_shape(), backend,
                          device_ordinal, allocation, allocator);
}

Status XRTTupleAllocation::ToLiteral(xla::Backend* backend,
                                     xla::MutableLiteralBase* literal) {
  mutex_lock lock(lock_);
  return literal_ == nullptr ? StoreToLiteral(backend, literal)
                             : literal->CopyFrom(*literal_);
}

Status XRTTupleAllocation::StoreToLiteral(xla::Backend* backend,
                                          xla::MutableLiteralBase* literal) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  return transfer_manager->TransferLiteralFromDevice(stream.get(),
                                                     shaped_buffer, literal);
}

Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape not matching the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    // The allocation is currently swapped out, and we have a host literal for
    // its content. Just update the host literal with the new value.
    return literal_->CopyFrom(literal);
  }
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
                                                   shaped_buffer);
}

xla::StatusOr<bool> XRTTupleAllocation::SwapOut(xla::Backend* backend,
                                                bool swap_pinned) {
  mutex_lock lock(lock_);
  if (literal_ == nullptr && (!IsPinned() || swap_pinned)) {
    xla::Literal literal(on_host_shape());
    TF_RETURN_IF_ERROR(StoreToLiteral(backend, &literal));
    ReleaseBuffers();
    literal_ = absl::make_unique<xla::Literal>(std::move(literal));
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::SwapIn(
    XRTMemoryManager* memory_manager, xla::Backend* backend,
    se::DeviceMemoryAllocator* allocator) {
  // We need to call AllocateScopedShapedBuffer() outside the locks, since the
  // XRTMemoryManager might end up calling back into the SwapOut() API.
  // So we do a quick check up front with the IsSwapped() API, and the
  // allocation may become swapped in after that check. In that case we will
  // end up doing an allocation and then releasing it soon after (via its
  // scoped variables). That scenario (two threads calling SwapIn() on the
  // same allocation) is unlikely, though.
  if (!IsSwapped()) {
    return false;
  }

  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(
      AllocateScopedShapedBuffer(memory_manager, backend, device_ordinal(),
                                 on_host_shape(), &scoped_buffer, allocator));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));

  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
        stream.get(), *literal_, *scoped_buffer));

    auto shaped_buffer = scoped_buffer->release();
    InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal());
    literal_ = nullptr;
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::PinAndSwapIn(
    XRTMemoryManager* memory_manager, xla::Backend* backend,
    se::DeviceMemoryAllocator* allocator) {
  Pin();
  return SwapIn(memory_manager, backend, allocator);
}
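
// Example of the swap lifecycle (a minimal sketch; error handling elided,
// and `memory_manager`, `backend` and `allocator` are assumed to be
// available from the caller's context):
//
//   // Move the tuple's device buffers into a host literal to free device
//   // memory; returns false if the allocation is pinned or already swapped.
//   TF_ASSIGN_OR_RETURN(bool swapped_out,
//                       allocation->SwapOut(backend, /*swap_pinned=*/false));
//   ...
//   // Later, pin the allocation and bring its buffers back onto the device
//   // before using it in an execution, then drop the pin.
//   TF_ASSIGN_OR_RETURN(bool swapped_in, allocation->PinAndSwapIn(
//                           memory_manager, backend, allocator));
//   allocation->Unpin();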

bool XRTTupleAllocation::IsSwapped() const {
  mutex_lock lock(lock_);
  return literal_ != nullptr;
}

int64_t XRTTupleAllocation::Pin() { return pin_count_.fetch_add(1); }

int64_t XRTTupleAllocation::Unpin() { return pin_count_.fetch_sub(1); }

bool XRTTupleAllocation::IsPinned() const { return pin_count_ != 0; }

void XRTTupleAllocation::DiscardAllocation(
    const xla::ShapeIndex& buffer_index) {
  buffers_.element(buffer_index)->DiscardAllocation();
}

const xla::Shape& XRTTupleAllocation::on_host_shape() const {
  return on_host_shape_;
}

const xla::Shape& XRTTupleAllocation::on_device_shape() const {
  return on_device_shape_;
}

int XRTTupleAllocation::device_ordinal() const { return device_ordinal_; }

const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() const {
  return buffers_.element({})->allocation();
}

/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
    XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
    XRTTupleAllocation** allocation, bool alias_parent_allocation) {
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* host_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* device_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));

  *allocation =
      new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
                             *host_sub_shape, *device_sub_shape);
  if (alias_parent_allocation) {
    // Copy the subtree of allocations from the parent allocation.
    (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
    // Increment the refcount on each aliased buffer.
    (*allocation)
        ->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
  } else {
    // Find the buffers in the parent allocation that match the subtree, and
    // move the parent allocation's buffer over to the new allocation.
    (*allocation)
        ->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              // Extend the allocation's index to the parent's frame by adding
              // subshape as a prefix.
              xla::ShapeIndex parent_index = subshape;
              for (int i = 0; i < index.size(); ++i) {
                parent_index.push_back(index[i]);
              }
              *buffer = parent->buffers_.element(parent_index);
              *parent->buffers_.mutable_element(parent_index) = nullptr;
            });
  }
  (*allocation)->SetDeviceMemorySize();
  return OkStatus();
}
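
// Example usage (a minimal sketch; `parent` is an existing tuple-shaped
// XRTTupleAllocation): extracting element {1} of the tuple as its own
// allocation while aliasing the parent's device memory rather than moving
// ownership out of it.
//
//   XRTTupleAllocation* sub = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
//       parent, /*subshape=*/{1}, &sub, /*alias_parent_allocation=*/true));
//   // `sub` holds its own references to the aliased buffers; Unref() it
//   // when done.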

void XRTTupleAllocation::SetDeviceMemorySize() {
  size_t size = 0;
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      size += index_buffer.second->allocation().size();
    }
  }
  device_memory_size_ = size;
}

/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
    const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
    se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
    xla::Shape* device_shape) {
  // Initialize both host and device shape to be the 'spine' of the new tuple
  // shape, given by the shape of the tree of tuples.
  *host_shape = elements.shape();
  *device_shape = elements.shape();
  // Now go over the leaves of the tree of tuples, and 'graft' the host/device
  // shapes of the allocation at that leaf onto the expanded host/device shapes
  // at the leaf position.
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (elements.IsLeaf(index)) {
          if (element.allocation == nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a null internal node at index ",
                index.ToString());
          }
          if (device_ordinal != element.allocation->device_ordinal() ||
              allocator != element.allocation->allocator_) {
            return errors::InvalidArgument(
                "MakeTuple elements must all be allocated on the same device "
                "as the destination.");
          }
          *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
              element.allocation->on_host_shape();
          *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
              element.allocation->on_device_shape();
        } else {
          if (element.allocation != nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a non-null internal node at index ",
                index.ToString());
          }
        }
        return OkStatus();
      }));
  return OkStatus();
}

/*static*/ Status XRTTupleAllocation::MakeTuple(
    XRTMemoryManager* memory_manager, xla::Backend* backend,
    int device_ordinal, const xla::ShapeTree<ExpandedTupleInput>& elements,
    XRTTupleAllocation** allocation, se::DeviceMemoryAllocator* allocator) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  xla::Shape host_shape;
  xla::Shape device_shape;
  TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
                                        &host_shape, &device_shape));

  // The aliasing is determined below based on whether or not all the inputs
  // are released while being transferred. allocation_tmp is a local pointer
  // that is copied to *allocation at the end only if the method succeeds.
  XRTTupleAllocation* allocation_tmp = new XRTTupleAllocation(
      device_ordinal, allocator, host_shape, device_shape);
  core::ScopedUnref allocation_unref(allocation_tmp);
  // First allocate device memory for the new tuple index tables, one at each
  // internal node of the elements tree. Do this in a separate pass into a
  // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
  // an allocation fails. Make sure the shape has layout so that the code that
  // writes index tables will be happy lower down.
  xla::Shape spine_shape = elements.shape();
  xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
  auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
      spine_shape, spine_shape, allocator, device_ordinal);
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (!elements.IsLeaf(index)) {
          const xla::Shape& subshape =
              xla::ShapeUtil::GetSubshape(device_shape, index);
          uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
          TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
                              memory_manager->Allocate(backend, device_ordinal,
                                                       size, allocator));
          VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index "
                  << index.ToString();
          // Move the new buffer into new_tuple_buffers, which takes ownership
          // of it.
          new_tuple_buffers->set_buffer(std::move(buffer), index);
        }
        return OkStatus();
      }));
  // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
  // the newly-allocated index tables. Right now there's no owner for the new
  // index tables, so next we will transfer ownership to the new allocation,
  // taking care not to return early on any errors in the meantime.
  xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
  // Now fill in the remaining data structures. After this ForEachElement
  // completes:
  //   1) Every leaf element of tuple_buffers will be the root buffer of an
  //      existing allocation, and every internal element of tuple_buffers
  //      will be a newly-allocated index table. tuple_buffers does not own
  //      any of these.
  //   2) Every element of allocation_tmp->buffers_ will be a correctly
  //      constructed XRTBufferAllocation wrapping the necessary allocations.
  //      For buffers in existing allocations there will be a new reference
  //      owned by the new allocation, and for newly-allocated index tables
  //      there will be a single reference owned by the new allocation.
  elements.ForEachElement([&](const xla::ShapeIndex& index,
                              const ExpandedTupleInput& element) {
    if (elements.IsLeaf(index)) {
      allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
                                               index);
      tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
      if (element.release_allocation_after_use) {
        // Transfer the references from element's buffers to the new allocation
        // rather than incrementing the refcount. The caller should have
        // validated that release_allocation_after_use is false if
        // element.allocation appears in more than one leaf.
        element.allocation->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex&, XRTBufferAllocationPtr* buffer) {
              *buffer = nullptr;
            });
      } else {
        // Increment the refcount on each newly-aliased buffer.
        element.allocation->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
      }
    } else {
      // This is an internal node of the tuple tree so take ownership of the
      // newly-created index table.
      *allocation_tmp->buffers_.mutable_element(index) =
          new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
                                  allocator);
    }
  });
  allocation_tmp->SetDeviceMemorySize();
  // Because the internal nodes of tuple_buffers are exactly the new index
  // tables, WriteTupleIndexTables will write only the new index tables and not
  // rewrite the index tables for the existing allocations.
  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));

  *allocation = allocation_tmp;
  // Get another reference since allocation_tmp will be Unrefed automatically
  // on exit.
  (*allocation)->Ref();
  return OkStatus();
}
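
// Example usage (a minimal sketch; `a0`, `a1`, `memory_manager`, `backend`
// and `allocator` are illustrative names assumed to exist): building a
// two-element tuple from two existing allocations and handing their buffers
// over to the new tuple. The spine shape only needs the right tree
// structure; ExpandTreeOfTuples() grafts the real leaf shapes onto it.
//
//   xla::Shape spine = xla::ShapeUtil::MakeTupleShape(
//       {a0->on_host_shape(), a1->on_host_shape()});
//   xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> elements(spine);
//   elements.mutable_element({0})->allocation = a0;
//   elements.mutable_element({0})->release_allocation_after_use = true;
//   elements.mutable_element({1})->allocation = a1;
//   elements.mutable_element({1})->release_allocation_after_use = true;
//   XRTTupleAllocation* tuple = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeTuple(
//       memory_manager, backend, /*device_ordinal=*/0, elements, &tuple,
//       allocator));
//   // `tuple` now owns the buffers previously held by a0 and a1.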

bool XRTTupleAllocation::IsExclusiveOwner() const {
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr &&
        !index_buffer.second->RefCountIsOne()) {
      return false;
    }
  }
  return true;
}

size_t XRTTupleAllocation::GetDeviceMemorySize() const {
  return device_memory_size_;
}

void XRTTupleAllocation::InitializeFromShapedBuffer(
    const xla::ShapedBuffer& shaped_buffer,
    se::DeviceMemoryAllocator* allocator, int device_ordinal) {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
    }
    // Make a reference-counted version of the allocated buffer.
    index_buffer.second = new XRTBufferAllocation(
        shaped_buffer.buffer(index_buffer.first), device_ordinal, allocator);
  }
}

xla::StatusOr<xla::ShapedBuffer> XRTTupleAllocation::ToShapedBuffer() {
  xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
                                  device_ordinal_);
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    shaped_buffer.set_buffer(index_buffer.second->allocation(),
                             index_buffer.first);
  }
  return std::move(shaped_buffer);
}

Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
                                           const xla::ShapeIndex& source_index,
                                           const xla::ShapeIndex& dest_index) {
  XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
  XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
  if (dest_buffer != nullptr) {
    // We allow the destination size to be zero, because there are cases where
    // we come in later to fill in null/uninitialized device buffers. In all
    // other cases, the size of the new buffer must match.
    if (source_buffer->allocation().size() !=
            dest_buffer->allocation().size() &&
        dest_buffer->allocation().size() != 0) {
      return errors::InvalidArgument(
          "Source buffer at index ", source_index.ToString(),
          " does not match the size of destination buffer at index ",
          dest_index.ToString(), ": ", source_buffer->allocation().size(),
          " vs ", dest_buffer->allocation().size());
    }
  } else {
    const xla::Shape& source_subshape =
        xla::ShapeUtil::GetSubshape(source.on_device_shape(), source_index);
    const xla::Shape& dest_subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape(), dest_index);
    if (!xla::ShapeUtil::Equal(source_subshape, dest_subshape)) {
      return errors::InvalidArgument(
          "Source and destination subshapes do not match: source=",
          xla::ShapeUtil::HumanStringWithLayout(source_subshape),
          " dest=", xla::ShapeUtil::HumanStringWithLayout(dest_subshape));
    }
  }
  *buffers_.mutable_element(dest_index) = source_buffer;
  source_buffer->Ref();
  if (dest_buffer != nullptr) {
    // If we handed over the ownership of a buffer in ToExecutionInput(), we
    // will be called here on the way back from execution, to alias back the
    // buffer at that index. In that case the buffers will be the same. So we
    // need to discard the memory at the destination buffer, before releasing
    // the reference.
    if (dest_buffer->allocation().IsSameAs(source_buffer->allocation()) &&
        dest_buffer != source_buffer) {
      dest_buffer->DiscardAllocation();
    }
    dest_buffer->Unref();
  }
  return OkStatus();
}

xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
        alias_checker) {
  xla::ExecutionInput result(on_device_shape(), on_host_shape());
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    TF_ASSIGN_OR_RETURN(bool should_alias, alias_checker(index_buffer.first));
    if (!should_alias) {
      result.SetBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(index_buffer.second->allocation()));
    } else {
      // We keep the ownership of the device memory here.
      result.SetUnownedBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(se::OwningDeviceMemory(
              index_buffer.second->allocation(), device_ordinal_, allocator_)));
    }
  }
  return std::move(result);
}
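
// Example usage (a minimal sketch): converting the allocation into an
// xla::ExecutionInput for execution, with a per-buffer alias_checker deciding
// how each buffer is passed. Here a trivial checker answers false for every
// index; a real caller would base the answer on the executable's
// input/output aliasing configuration.
//
//   TF_ASSIGN_OR_RETURN(
//       xla::ExecutionInput input,
//       allocation->ToExecutionInput(
//           [](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
//             return false;
//           }));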

}  // namespace tensorflow