/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transfer_manager.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"

using absl::StrCat;

namespace xla {

/* static */ absl::Mutex TransferManager::platform_transfer_manager_mutex_(
    absl::kConstInit);

/* static */ absl::flat_hash_map<se::Platform::Id, TransferManager::State>*
TransferManager::GetPlatformTransferManagers() {
  static auto* r =
      new absl::flat_hash_map<se::Platform::Id, TransferManager::State>;
  return r;
}

TransferManager::TransferMetadata::~TransferMetadata() {}

StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;

  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };

  tensorflow::Notification n;
  Status s;
  Literal literal(device_buffer.on_host_shape());
  TransferLiteralFromDevice(
      substream, device_buffer, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const MutableBorrowingLiteral& literal,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };

  Status ret;
  tensorflow::Notification n;
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        ret = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  return ret;
}

Status TransferManager::TransferLiteralToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };
  TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
      substream, literal, device_buffer, transfer_metadata));
  return substream->BlockHostUntilDone();
}

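// Example of a synchronous host-to-device and device-to-host round trip using
// the wrappers above. This is an illustrative sketch, not code used in this
// file: `transfer_manager`, `stream`, and `device_buffer` are assumed to have
// been created by the caller, and the optional `transfer_metadata` argument is
// omitted (assumed to default to nullptr in the header).
//
//   Literal input = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
//       stream, input, device_buffer));
//   TF_ASSIGN_OR_RETURN(
//       Literal output,
//       transfer_manager->TransferLiteralFromDevice(stream, device_buffer));
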
StatusOr<Literal> TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };

  tensorflow::Notification n;
  Literal literal(shape);
  Status s;
  TransferArrayFromDevice(
      substream, shape, source, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); };
  TF_RETURN_IF_ERROR(
      TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
  return substream->BlockHostUntilDone();
}

Status TransferManager::TransferArrayToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
  TF_RET_CHECK(on_device_shape.IsArray())
      << "On-device representation of "
      << ShapeUtil::HumanString(literal.shape())
      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
    return FailedPrecondition(
        "Allocation on device not large enough for array: "
        "%d < %d",
        dest.size(), GetByteSizeRequirement(on_device_shape));
  }
  ShapedBuffer shaped_buffer(on_device_shape,
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(dest, /*index=*/{});
  return TransferLiteralToDevice(stream, literal, shaped_buffer,
                                 transfer_metadata);
}

void TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  if (!Shape::Equal().MinorToMajorOnlyInLayout()(HostShapeToDeviceShape(shape),
                                                 shape)) {
    auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
                        " has a differently shaped representation on-device: ",
                        ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
    return done(FailedPrecondition("%s", error));
  }
  if (source.size() < GetByteSizeRequirement(shape)) {
    return done(
        FailedPrecondition("Allocation on device not large enough for array: "
                           "%d < %d",
                           source.size(), GetByteSizeRequirement(shape)));
  }
  ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(source, /*index=*/{});
  return TransferLiteralFromDevice(stream, shaped_buffer, literal,
                                   std::move(done), transfer_metadata);
}

Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                          ShapedBuffer* device_buffer,
                                          Shape* device_shape) {
  DCHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  TF_ASSIGN_OR_RETURN(auto compiler,
                      Compiler::GetForPlatform(stream->parent()->platform()));
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return OkStatus();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return OkStatus();
        }

        // Read the dynamic shape metadata from the device stream.  The dynamic
        // shape itself is stored at the end of the buffer.
        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64_t offset = shape_size_fn(buffer_shape_static);
        int64_t metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument("Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = se::DeviceMemory<uint8_t>(*buffer);
        auto metadata_buffer =
            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
        TF_ASSIGN_OR_RETURN(
            auto metadata,
            TransferArrayFromDevice(
                stream,
                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
                metadata_buffer));

        // Update shape size from metadata.
        for (int64_t i = 0; i < metadata.element_count(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32_t>({i});
        }
        return OkStatus();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return OkStatus();
}

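// Worked example of the layout ReadDynamicShapes relies on (illustrative
// figures only, assuming the backend reserves one int32 per dimension after
// the padded array data, as the S32 read above implies): for a dynamic buffer
// whose padded static shape is f32[4,8], offset = shape_size_fn(f32[4,8]) =
// 4 * 8 * 4 = 128 bytes of array data, followed by dimensions_size() = 2
// int32 values holding the true runtime extents, so the metadata read above
// covers 2 * sizeof(int32) = 8 bytes at the end of the allocation.
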
/* static */ void TransferManager::RegisterTransferManager(
    se::Platform::Id platform_id,
    TransferManagerCreationFunction creation_function) {
  absl::MutexLock lock(&TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();
  CHECK(managers->find(platform_id) == managers->end());
  (*managers)[platform_id].creation_function = creation_function;
}

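// Backends typically invoke RegisterTransferManager from a module initializer
// so the factory is available before GetForPlatform is first called. A minimal
// sketch (`MyPlatformId` and `CreateMyTransferManager` are placeholder names
// for illustration, not symbols defined here):
//
//   static bool transfer_manager_registered = [] {
//     xla::TransferManager::RegisterTransferManager(MyPlatformId(),
//                                                   &CreateMyTransferManager);
//     return true;
//   }();
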
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
    const se::Platform* platform) {
  absl::MutexLock lock(&TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();

  auto it = managers->find(platform->id());
  if (it == managers->end()) {
    return NotFound(
        "could not find registered transfer manager for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.manager == nullptr) {
    // Lazily create the transfer manager the first time it is needed.
    it->second.manager = (*it->second.creation_function)();
  }

  return it->second.manager.get();
}

Status TransferManager::WriteTupleIndexTables(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
  return stream->BlockHostUntilDone();
}

Status TransferManager::WriteTupleIndexTablesAsync(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  VLOG(2) << "Writing tuple index tables for " << device_buffer;

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        if (device_subshape.IsTuple() &&
            ShapeUtil::TupleElementCount(device_subshape) > 0) {
          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());

          std::vector<se::DeviceMemoryBase> elements;
          ShapeIndex element_index = index;
          for (int64_t i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
               ++i) {
            element_index.push_back(i);
            elements.push_back(device_buffer.buffer(element_index));
            element_index.pop_back();
          }
          return WriteSingleTupleIndexTable(stream, elements, device_subshape,
                                            &device_memory);
        }

        return OkStatus();
      });
}

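// For example (illustrative only), a tuple shaped (f32[4], s32[2]) occupies
// one device allocation per leaf array plus a root allocation. The loop above
// gathers the leaf addresses, and WriteSingleTupleIndexTable, implemented by
// each backend, writes them into the root allocation so on-device code can
// dereference the tuple:
//
//   // root buffer (index {}):  [ &leaf0, &leaf1 ]  <- index table written here
//   // leaf0       (index {0}): f32[4] data
//   // leaf1       (index {1}): s32[2] data
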
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
  if (ShapeUtil::TupleElementCount(device_buffer.on_device_shape()) == 0) {
    return OkStatus();
  }
  se::DeviceMemoryBase device_memory = device_buffer.buffer({});
  TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0;
       i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) {
    elements.push_back(device_buffer.buffer({i}));
  }
  return WriteSingleTupleIndexTable(
      stream, elements, device_buffer.on_device_shape(), &device_memory);
}

Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree) {
  TF_RET_CHECK(buffer_tree.shape().IsTuple());
  if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) {
    return OkStatus();
  }
  se::DeviceMemoryBase device_memory =
      buffer_tree.element({}).AsDeviceMemoryBase();
  TF_RET_CHECK(GetByteSizeRequirement(buffer_tree.shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape());
       ++i) {
    elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase());
  }
  return WriteSingleTupleIndexTable(stream, elements, buffer_tree.shape(),
                                    &device_memory);
}

Status TransferManager::TransferBufferFromDevice(
    se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size,
    void* destination) {
  if (source.size() < size) {
    return FailedPrecondition(
        "Source allocation on device not large enough for data transfer: "
        "%d < %d",
        source.size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return OkStatus();
}

Status TransferManager::TransferBufferToDevice(
    se::Stream* stream, int64_t size, const void* source,
    se::DeviceMemoryBase* destination) {
  if (destination->size() < size) {
    return FailedPrecondition(
        "Destination allocation on device not large enough for data transfer: "
        "%d < %d",
        destination->size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return OkStatus();
}

StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
    const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
    int device_ordinal, DeviceShapeRepresentationFn shape_representation_fn) {
  if (!LayoutUtil::HasLayout(on_host_shape)) {
    return InvalidArgument("Shape must have a layout: %s",
                           ShapeUtil::HumanStringWithLayout(on_host_shape));
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
  Shape on_device_shape = (shape_representation_fn == nullptr)
                              ? HostShapeToDeviceShape(on_host_shape)
                              : shape_representation_fn(on_host_shape);
  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));

  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
                                   device_ordinal);

  // Allocate an appropriately sized buffer for each element in the shape,
  // including the tuple pointer arrays.
  for (auto& pair : shaped_buffer.buffers()) {
    const ShapeIndex& index = pair.first;
    se::DeviceMemoryBase& memory_base = pair.second;
    const Shape& subshape =
        ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
    TF_ASSIGN_OR_RETURN(auto memory,
                        allocator->Allocate(shaped_buffer.device_ordinal(),
                                            GetByteSizeRequirement(subshape),
                                            /*retry_on_failure=*/true,
                                            LayoutUtil::MemorySpace(subshape)));
    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
    memory_base = memory.Release();
  }

  return std::move(shaped_buffer);
}

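// Example tying allocation and transfer together (an illustrative sketch, not
// code used in this file: `transfer_manager`, `allocator`, `stream`, and
// `literal` are assumed to exist, and device ordinal 0 is used for
// concreteness):
//
//   TF_ASSIGN_OR_RETURN(
//       ScopedShapedBuffer buffer,
//       transfer_manager->AllocateScopedShapedBuffer(
//           literal.shape(), allocator, /*device_ordinal=*/0,
//           /*shape_representation_fn=*/nullptr));
//   TF_RETURN_IF_ERROR(
//       transfer_manager->WriteTupleIndexTables(stream, buffer));
//   TF_RETURN_IF_ERROR(
//       transfer_manager->TransferLiteralToDevice(stream, literal, buffer));
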
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
    const Shape& host_shape) const {
  return LayoutUtil::GetWithDefaultLayout(host_shape);
}

xla::Shape TransferManager::ChooseGoodInfeedLayout(const Shape& shape) const {
  return LayoutUtil::GetWithDefaultLayout(shape);
}

}  // namespace xla