/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_transfer_manager.h"
#include <atomic>
#include <cstdint>
#include <cstring>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/noncopyable_buffer.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_id.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/tpu/tpu_api.h"

namespace tensorflow {
namespace tpu {

using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;

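// `manager_` wraps the underlying XLA_TransferManager owned through the C API;
// it is created here and released in the destructor.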
TpuTransferManager::TpuTransferManager() {
  manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
}

TpuTransferManager::~TpuTransferManager() {
  tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
}

stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
  return GetTpuPlatformId();
}

xla::Shape TpuTransferManager::HostShapeToDeviceShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;

  ApiConverter::ToC(host_shape, &c_host_shape);

  tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
      manager_, &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Destroy(&c_host_shape);
  ApiConverter::Destroy(&c_device_shape);
  return device_shape;
}

Status TpuTransferManager::TransferLiteralToDeviceAsync(
    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
    const xla::ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusHelper status;

  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_literal, &c_device_buffer, status.c_status);
  ApiConverter::Destroy(&c_device_buffer);
  ApiConverter::Destroy(&c_literal);
  return status.status();
}

Status TpuTransferManager::TransferLiteralToInfeed(
    stream_executor::StreamExecutor* executor,
    const xla::LiteralSlice& literal) {
  StatusHelper status;
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToInfeedFn(
      manager_, tpu_executor->se_executor(), &c_literal, status.c_status);

  ApiConverter::Destroy(&c_literal);

  return status.status();
}

Status TpuTransferManager::TransferBuffersToInfeed(
    se::StreamExecutor* executor,
    const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
  StatusHelper status;
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

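  // Flatten the deque into parallel pointer/size arrays, the layout the
  // C API expects for the infeed call below.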
  std::vector<int64_t> buffers_size;
  std::vector<uint32_t*> buffers_array;

  buffers_size.reserve(buffers.size());
  buffers_array.reserve(buffers.size());

  for (int64_t i = 0; i < buffers.size(); ++i) {
    absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
    buffers_array.push_back(const_cast<uint32_t*>(span.data()));
    buffers_size.push_back(span.size());
  }

  tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
      manager_, tpu_executor->se_executor(), buffers_array.data(),
      buffers_size.data(), buffers_size.size(), status.c_status);
  return status.status();
}

Status TpuTransferManager::TransferLiteralFromOutfeed(
    stream_executor::StreamExecutor* executor,
    xla::MutableBorrowingLiteral literal) {
  StatusHelper status;
  XLA_Shape c_shape;
  XLA_Literal c_literal;
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  ApiConverter::ToC(literal.shape(), &c_shape);
  ApiConverter::ToC(literal, &c_literal);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
      manager_, tpu_executor->se_executor(), &c_shape, &c_literal,
      status.c_status);

  ApiConverter::Destroy(&c_shape);
  ApiConverter::Destroy(&c_literal);

  return status.status();
}

Status TpuTransferManager::ResetDevices(
    absl::Span<stream_executor::StreamExecutor* const> executor) {
  StatusHelper status;
  std::vector<SE_StreamExecutor*> se;
  se.reserve(executor.size());
  for (int64_t i = 0; i < executor.size(); ++i) {
    se.push_back(static_cast<TpuExecutor*>(executor[i]->implementation())
                     ->se_executor());
  }

  tpu::ExecutorApiFn()->TpuTransferManager_ResetDevicesFn(
      manager_, se.data(), executor.size(), status.c_status);
  return status.status();
}

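// Tracks one or more in-flight transfers started by TransferLiteralFromDevice.
// The state is heap-allocated and deletes itself once the last transfer has
// reported completion, so it may outlive the initiating call.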
struct TransferFromDeviceState {
  std::atomic<int64_t> remaining_transfers;
  TF_Status* overall_status =
      tpu::ExecutorApiFn()->TpuStatus_NewFn();  // OK or the first error
  std::function<void(Status)> done;

  void TransferFinished(TF_Status* status) {
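    // Keep the first error encountered as the overall status; every later
    // status (error or OK) is freed below.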
    if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
        tpu::ExecutorApiFn()->TpuStatus_OkFn(overall_status)) {
      std::swap(overall_status, status);
    }
    tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);

    if (--remaining_transfers == 0) {
      done(StatusHelper::FromC(overall_status));
      tpu::ExecutorApiFn()->TpuStatus_FreeFn(overall_status);
      delete this;
    }
  }
};

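// C-style callback passed to the TPU runtime; forwards completion to the
// owning TransferFromDeviceState.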
void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) {
  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
}

void TpuTransferManager::TransferLiteralFromDevice(
    stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
    xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
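  // Ownership of `state` passes to the runtime callback: TransferFinished
  // deletes it once the transfer completes, so it must not be freed here.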
  TransferFromDeviceState* state = new TransferFromDeviceState;
  state->remaining_transfers = 1;
  state->done = done;
  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
  ApiConverter::Destroy(&c_device_buffer);
  ApiConverter::Destroy(&c_literal);
}

int64_t TpuTransferManager::GetByteSizeRequirement(
    const xla::Shape& shape) const {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);

  int64_t size_in_bytes =
      tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
          manager_, &c_shape);

  ApiConverter::Destroy(&c_shape);
  return size_in_bytes;
}

StatusOr<xla::Shape> TpuTransferManager::ChooseCompactLayoutForShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  XLA_Shape c_output;
  StatusHelper status;
  tpu::ExecutorApiFn()->TpuTransferManager_ChooseCompactLayoutForShapeFn(
      manager_, &c_host_shape, &c_output, status.c_status);
  // TODO(skyewm): use a scoped version of XLA_Shape
  ApiConverter::Destroy(&c_host_shape);
  if (!status.status().ok()) {
    ApiConverter::Destroy(&c_output);
    return status.status();
  }
  xla::Shape output = ApiConverter::FromC(&c_output);
  ApiConverter::Destroy(&c_output);
  return output;
}

bool TpuTransferManager::CanShapedBufferBeAccessedNow(
    stream_executor::StreamExecutor* executor,
    const xla::ShapedBuffer& device_buffer) const {
  auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);
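  // Destroy the C copy of the shaped buffer no matter how we return.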
  absl::Cleanup cleanup = [&c_device_buffer]() {
    ApiConverter::Destroy(&c_device_buffer);
  };
  return tpu::ExecutorApiFn()
      ->TpuTransferManager_CanShapedBufferBeAccessedNowFn(
          manager_, tpu_executor->se_executor(), &c_device_buffer);
}

bool TpuTransferManager::CanBufferBeAccessedNow(
    se::StreamExecutor* executor,
    const se::DeviceMemoryBase& device_buffer) const {
  auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
  SE_DeviceMemoryBase c_device_buffer{const_cast<void*>(device_buffer.opaque()),
                                      device_buffer.size(),
                                      device_buffer.payload()};
  return tpu::ExecutorApiFn()->TpuTransferManager_CanBufferBeAccessedNowFn(
      manager_, tpu_executor->se_executor(), &c_device_buffer);
}

Status TpuTransferManager::WriteSingleTupleIndexTable(
    stream_executor::Stream* stream,
    absl::Span<const stream_executor::DeviceMemoryBase> elements,
    const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
  CHECK_GT(elements.size(), 0);
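  // Marshal each element buffer into the C struct layout the API expects;
  // the array is freed after the call below.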
  SE_DeviceMemoryBase* elements_bases =
      new SE_DeviceMemoryBase[elements.size()];
  for (int i = 0; i < elements.size(); i++) {
    elements_bases[i] =
        SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
                            elements[i].size(), elements[i].payload()};
  }
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
                                  region->payload()};
  StatusHelper status;

  tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      elements_bases, elements.size(), &c_shape, &region_base, status.c_status);

  delete[] elements_bases;
  ApiConverter::Destroy(&c_shape);
  return status.status();
}

Status TpuTransferManager::LinearizeToBuffers(
    const xla::LiteralSlice& literal,
    std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) {
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  char** buffers_array;
  int64_t* buffers_size;
  int64_t buffers_array_size;
  StatusHelper status;

  tpu::ExecutorApiFn()->TpuTransferManager_LinearizeToBuffersFn(
      manager_, &c_literal, &buffers_array, &buffers_size, &buffers_array_size,
      status.c_status);

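  // Copy each linearized buffer into an owned NoncopyableBuffer; the C API
  // retains ownership of `buffers_array` until FreeBuffersFn releases it.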
  for (int64_t i = 0; i < buffers_array_size; ++i) {
    tpu::NoncopyableBuffer buf(buffers_size[i]);
    memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
           buffers_size[i]);
    buffers->push_back(std::move(buf));
  }

  tpu::ExecutorApiFn()->TpuTransferManager_FreeBuffersFn(
      buffers_array, buffers_size, buffers_array_size);

  ApiConverter::Destroy(&c_literal);
  return status.status();
}

Status TpuTransferManager::ReadDynamicShapes(se::Stream* stream,
                                             xla::ShapedBuffer* device_buffer,
                                             xla::Shape* device_shape) {
  XLA_ShapedBuffer c_device_buffer;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(*device_buffer, &c_device_buffer);
  ApiConverter::ToC(*device_shape, &c_device_shape);
  XLA_Shape c_updated_shape;
  StatusHelper status;
  ExecutorApiFn()->TpuTransferManager_ReadDynamicShapesFn(
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_device_buffer, c_device_shape, &c_updated_shape, status.c_status);
  ApiConverter::Destroy(&c_device_buffer);
  ApiConverter::Destroy(&c_device_shape);
  if (!status.ok()) {
    return status.status();
  }
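  // Success: overwrite the caller's shape with the updated dynamic shape read
  // back from the device.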
  *device_shape = ApiConverter::FromC(&c_updated_shape);
  ApiConverter::Destroy(&c_updated_shape);
  return OkStatus();
}

}  // namespace tpu
}  // namespace tensorflow