/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transfer_manager.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"

using absl::StrCat;

namespace xla {

/* static */ absl::Mutex TransferManager::platform_transfer_manager_mutex_(
    absl::kConstInit);

/* static */ absl::flat_hash_map<se::Platform::Id, TransferManager::State>*
TransferManager::GetPlatformTransferManagers() {
  static auto* r =
      new absl::flat_hash_map<se::Platform::Id, TransferManager::State>;
  return r;
}

TransferManager::TransferMetadata::~TransferMetadata() {}

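// Synchronous wrapper around the callback-based TransferLiteralFromDevice
// below: the transfer is issued on a substream (so a caller inside a host
// callback does not deadlock the main stream) and a Notification blocks the
// host until the done-callback reports a status.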
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;

  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };

  tensorflow::Notification n;
  Status s;
  Literal literal(device_buffer.on_host_shape());
  TransferLiteralFromDevice(
      substream, device_buffer, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const MutableBorrowingLiteral& literal,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };

  Status ret;
  tensorflow::Notification n;
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        ret = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  return ret;
}

Status TransferManager::TransferLiteralToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };
  TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
      substream, literal, device_buffer, transfer_metadata));
  return substream->BlockHostUntilDone();
}

StatusOr<Literal> TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };

  tensorflow::Notification n;
  Literal literal(shape);
  Status s;
  TransferArrayFromDevice(
      substream, shape, source, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };
  TF_RETURN_IF_ERROR(
      TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
  return substream->BlockHostUntilDone();
}

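// Transfers a single array (non-tuple) literal. The raw device allocation in
// `dest` is checked against the byte size of the on-device shape, wrapped in
// a one-buffer ShapedBuffer, and handed to TransferLiteralToDevice.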
Status TransferManager::TransferArrayToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
  TF_RET_CHECK(on_device_shape.IsArray())
      << "On-device representation of "
      << ShapeUtil::HumanString(literal.shape())
      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
    return FailedPrecondition(
        "Allocation on device not large enough for array: "
        "%d < %d",
        dest.size(), GetByteSizeRequirement(on_device_shape));
  }
  ShapedBuffer shaped_buffer(on_device_shape,
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(dest, /*index=*/{});
  return TransferLiteralToDevice(stream, literal, shaped_buffer,
                                 transfer_metadata);
}

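// Callback-based array transfer in the device-to-host direction. The host
// shape must match the on-device representation (comparing only the layout's
// minor-to-major order) and `source` must be large enough; errors are
// reported through `done` rather than returned.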
void TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  if (!Shape::Equal().MinorToMajorOnlyInLayout()(HostShapeToDeviceShape(shape),
                                                 shape)) {
    auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
                        " has a differently shaped representation on-device: ",
                        ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
    return done(FailedPrecondition("%s", error));
  }
  if (source.size() < GetByteSizeRequirement(shape)) {
    return done(
        FailedPrecondition("Allocation on device not large enough for array: "
                           "%d < %d",
                           source.size(), GetByteSizeRequirement(shape)));
  }
  ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(source, /*index=*/{});
  return TransferLiteralFromDevice(stream, shaped_buffer, literal,
                                   std::move(done), transfer_metadata);
}

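// Refreshes the dynamic dimensions of `device_shape` from metadata the device
// has written past the static portion of each buffer: for every dynamic
// subshape, the trailing int32 values (one per dimension) are read back with
// TransferArrayFromDevice and copied into the shape, which is then marked
// fully static.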
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                          ShapedBuffer* device_buffer,
                                          Shape* device_shape) {
  DCHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  TF_ASSIGN_OR_RETURN(auto compiler,
                      Compiler::GetForPlatform(stream->parent()->platform()));
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return OkStatus();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return OkStatus();
        }

        // Read the dynamic shape metadata from the device stream. The dynamic
        // shape itself is stored at the end of the buffer.
        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64_t offset = shape_size_fn(buffer_shape_static);
        int64_t metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument("Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = se::DeviceMemory<uint8_t>(*buffer);
        auto metadata_buffer =
            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
        TF_ASSIGN_OR_RETURN(
            auto metadata,
            TransferArrayFromDevice(
                stream,
                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
                metadata_buffer));

        // Update shape size from metadata.
        for (int64_t i = 0; i < metadata.element_count(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32_t>({i});
        }
        return OkStatus();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return OkStatus();
}

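// Registers `creation_function` in the static platform-id-keyed registry.
// Each platform may be registered at most once; the actual TransferManager is
// only constructed lazily by GetForPlatform.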
/* static */ void TransferManager::RegisterTransferManager(
    se::Platform::Id platform_id,
    TransferManagerCreationFunction creation_function) {
  absl::MutexLock lock(&TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();
  CHECK(managers->find(platform_id) == managers->end());
  (*managers)[platform_id].creation_function = creation_function;
}

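// Looks up the TransferManager registered for `platform`, constructing it on
// first use. Returns NotFound if no creation function was linked in for the
// platform.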
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
    const se::Platform* platform) {
  absl::MutexLock lock(&TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();

  auto it = managers->find(platform->id());
  if (it == managers->end()) {
    return NotFound(
        "could not find registered transfer manager for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.manager == nullptr) {
    // Lazily create the transfer manager the first time it is needed
    it->second.manager = (*it->second.creation_function)();
  }

  return it->second.manager.get();
}

Status TransferManager::WriteTupleIndexTables(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
  return stream->BlockHostUntilDone();
}

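// Walks every tuple subshape of `device_buffer` and asks the backend-specific
// WriteSingleTupleIndexTable to write that tuple's table of element buffer
// addresses into the tuple's own allocation.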
Status TransferManager::WriteTupleIndexTablesAsync(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  VLOG(2) << "Writing tuple index tables for " << device_buffer;

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        if (device_subshape.IsTuple() &&
            ShapeUtil::TupleElementCount(device_subshape) > 0) {
          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());

          std::vector<se::DeviceMemoryBase> elements;
          ShapeIndex element_index = index;
          for (int64_t i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
               ++i) {
            element_index.push_back(i);
            elements.push_back(device_buffer.buffer(element_index));
            element_index.pop_back();
          }
          return WriteSingleTupleIndexTable(stream, elements, device_subshape,
                                            &device_memory);
        }

        return OkStatus();
      });
}

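// Like WriteTupleIndexTablesAsync, but only for the root tuple of
// `device_buffer`; nested tuples are left untouched. The overload below does
// the same for a ShapeTree of (possibly owning) device memory.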
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
  if (ShapeUtil::TupleElementCount(device_buffer.on_device_shape()) == 0) {
    return OkStatus();
  }
  se::DeviceMemoryBase device_memory = device_buffer.buffer({});
  TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0;
       i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) {
    elements.push_back(device_buffer.buffer({i}));
  }
  return WriteSingleTupleIndexTable(
      stream, elements, device_buffer.on_device_shape(), &device_memory);
}

Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree) {
  TF_RET_CHECK(buffer_tree.shape().IsTuple());
  if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) {
    return OkStatus();
  }
  se::DeviceMemoryBase device_memory =
      buffer_tree.element({}).AsDeviceMemoryBase();
  TF_RET_CHECK(GetByteSizeRequirement(buffer_tree.shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape());
       ++i) {
    elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase());
  }
  return WriteSingleTupleIndexTable(stream, elements, buffer_tree.shape(),
                                    &device_memory);
}

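// TransferBufferFromDevice and TransferBufferToDevice are thin,
// bounds-checked wrappers around Stream::ThenMemcpy for untyped byte
// transfers.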
Status TransferManager::TransferBufferFromDevice(
    se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size,
    void* destination) {
  if (source.size() < size) {
    return FailedPrecondition(
        "Source allocation on device not large enough for data transfer: "
        "%d < %d",
        source.size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return OkStatus();
}

Status TransferManager::TransferBufferToDevice(
    se::Stream* stream, int64_t size, const void* source,
    se::DeviceMemoryBase* destination) {
  if (destination->size() < size) {
    return FailedPrecondition(
        "Destination allocation on device not large enough for data transfer: "
        "%d < %d",
        destination->size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return OkStatus();
}

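// Allocates one device buffer per leaf array and per tuple index table of the
// on-device shape, without initializing any of them; ownership is transferred
// to the returned ScopedShapedBuffer. Callers still have to populate the
// buffers, e.g. with WriteTupleIndexTables and TransferLiteralToDevice.
//
// Rough usage sketch (not part of this file; `transfer_manager`, `allocator`,
// `stream`, and `literal` are assumed to be set up by the caller):
//
//   TF_ASSIGN_OR_RETURN(
//       ScopedShapedBuffer buffer,
//       transfer_manager->AllocateScopedShapedBuffer(
//           literal.shape(), allocator, stream->parent()->device_ordinal(),
//           /*shape_representation_fn=*/nullptr));
//   TF_RETURN_IF_ERROR(
//       transfer_manager->WriteTupleIndexTables(stream, buffer));
//   TF_RETURN_IF_ERROR(
//       transfer_manager->TransferLiteralToDevice(stream, literal, buffer));
//   TF_ASSIGN_OR_RETURN(
//       Literal round_trip,
//       transfer_manager->TransferLiteralFromDevice(stream, buffer));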
StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
    const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
    int device_ordinal, DeviceShapeRepresentationFn shape_representation_fn) {
  if (!LayoutUtil::HasLayout(on_host_shape)) {
    return InvalidArgument("Shape must have a layout: %s",
                           ShapeUtil::HumanStringWithLayout(on_host_shape));
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
  Shape on_device_shape = (shape_representation_fn == nullptr)
                              ? HostShapeToDeviceShape(on_host_shape)
                              : shape_representation_fn(on_host_shape);
  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));

  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
                                   device_ordinal);

  // Allocate an appropriately sized buffer for each element in the shape,
  // including the tuple pointer arrays.
  for (auto& pair : shaped_buffer.buffers()) {
    const ShapeIndex& index = pair.first;
    se::DeviceMemoryBase& memory_base = pair.second;
    const Shape& subshape =
        ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
    TF_ASSIGN_OR_RETURN(auto memory,
                        allocator->Allocate(shaped_buffer.device_ordinal(),
                                            GetByteSizeRequirement(subshape),
                                            /*retry_on_failure=*/true,
                                            LayoutUtil::MemorySpace(subshape)));
    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
    memory_base = memory.Release();
  }

  return std::move(shaped_buffer);
}

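// Layout selection hooks: the base implementations below simply fall back to
// LayoutUtil's default layout. Backends with stronger compaction or infeed
// layout preferences are expected to override these.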
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
    const Shape& host_shape) const {
  return LayoutUtil::GetWithDefaultLayout(host_shape);
}

xla::Shape TransferManager::ChooseGoodInfeedLayout(const Shape& shape) const {
  return LayoutUtil::GetWithDefaultLayout(shape);
}

}  // namespace xla