1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_transfer_manager.h"
17
18 #include <utility>
19
20 #include "absl/cleanup/cleanup.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
24 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h"
25 #include "tensorflow/compiler/xla/stream_executor/tpu/noncopyable_buffer.h"
26 #include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h"
28 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor.h"
29 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h"
30 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h"
31 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_id.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/tpu/tpu_api.h"
34
35 namespace tensorflow {
36 namespace tpu {
37
38 using Status = stream_executor::port::Status;
39 template <typename T>
40 using StatusOr = stream_executor::port::StatusOr<T>;
41
// Acquires the underlying C-API transfer-manager handle; released by the
// destructor.
TpuTransferManager::TpuTransferManager() {
  manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
}
45
// Releases the C-API transfer-manager handle acquired in the constructor.
TpuTransferManager::~TpuTransferManager() {
  tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
}
49
// Identifies this transfer manager as belonging to the TPU platform.
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
  return GetTpuPlatformId();
}
53
// Maps a host-side shape to its on-device layout by round-tripping through
// the C API: convert to C structs, call through the function table, convert
// the result back, then free both temporaries.
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
    const xla::Shape& host_shape) const {
  XLA_Shape host_c;
  XLA_Shape device_c;

  ApiConverter::ToC(host_shape, &host_c);
  tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
      manager_, &host_c, &device_c);

  xla::Shape result = ApiConverter::FromC(&device_c);
  ApiConverter::Destroy(&device_c);
  ApiConverter::Destroy(&host_c);
  return result;
}
68
// Enqueues an asynchronous host->device transfer of `literal` into
// `device_buffer` on `stream` via the C API.
// NOTE(review): `transfer_metadata` is accepted for interface compatibility
// but is not forwarded to the C call.
Status TpuTransferManager::TransferLiteralToDeviceAsync(
    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
    const xla::ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusHelper status;

  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
      manager_,
      // Translate the stream-executor stream to the C-API stream handle.
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_literal, &c_device_buffer, status.c_status);
  // Free the temporary C representations once the call has returned.
  ApiConverter::Destroy(&c_device_buffer);
  ApiConverter::Destroy(&c_literal);
  return status.status();
}
90
TransferLiteralToInfeed(stream_executor::StreamExecutor * executor,const xla::LiteralSlice & literal)91 Status TpuTransferManager::TransferLiteralToInfeed(
92 stream_executor::StreamExecutor* executor,
93 const xla::LiteralSlice& literal) {
94 StatusHelper status;
95 XLA_Literal c_literal;
96 ApiConverter::ToC(literal, &c_literal);
97 auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
98
99 tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToInfeedFn(
100 manager_, tpu_executor->se_executor(), &c_literal, status.c_status);
101
102 ApiConverter::Destroy(&c_literal);
103
104 return status.status();
105 }
106
TransferBuffersToInfeed(se::StreamExecutor * executor,const std::deque<tensorflow::tpu::NoncopyableBuffer> & buffers)107 Status TpuTransferManager::TransferBuffersToInfeed(
108 se::StreamExecutor* executor,
109 const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
110 StatusHelper status;
111 auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
112
113 std::vector<int64_t> buffers_size;
114 std::vector<uint32_t*> buffers_array;
115
116 buffers_size.reserve(buffers.size());
117 buffers_array.reserve(buffers.size());
118
119 for (int64_t i = 0; i < buffers.size(); ++i) {
120 absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
121 buffers_array.push_back(const_cast<uint32_t*>(span.data()));
122 buffers_size.push_back(span.size());
123 }
124
125 tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
126 manager_, tpu_executor->se_executor(), buffers_array.data(),
127 buffers_size.data(), buffers_size.size(), status.c_status);
128 return status.status();
129 }
130
// Pops one value of `literal.shape()` from the outfeed queue of the TPU
// behind `executor` and writes it into `literal` via the C API.
Status TpuTransferManager::TransferLiteralFromOutfeed(
    stream_executor::StreamExecutor* executor,
    xla::MutableBorrowingLiteral literal) {
  StatusHelper status;
  XLA_Shape c_shape;
  XLA_Literal c_literal;
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  // The expected shape is passed separately from the destination literal.
  ApiConverter::ToC(literal.shape(), &c_shape);
  ApiConverter::ToC(literal, &c_literal);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
      manager_, tpu_executor->se_executor(), &c_shape, &c_literal,
      status.c_status);

  ApiConverter::Destroy(&c_shape);
  ApiConverter::Destroy(&c_literal);

  return status.status();
}
151
ResetDevices(absl::Span<stream_executor::StreamExecutor * const> executor)152 Status TpuTransferManager::ResetDevices(
153 absl::Span<stream_executor::StreamExecutor* const> executor) {
154 StatusHelper status;
155 std::vector<SE_StreamExecutor*> se;
156 se.reserve(executor.size());
157 for (int64_t i = 0; i < executor.size(); ++i) {
158 se.push_back(static_cast<TpuExecutor*>(executor[i]->implementation())
159 ->se_executor());
160 }
161
162 tpu::ExecutorApiFn()->TpuTransferManager_ResetDevicesFn(
163 manager_, se.data(), executor.size(), status.c_status);
164 return status.status();
165 }
166
// Shared bookkeeping for an asynchronous device->host transfer that may be
// split into several sub-transfers.  Heap-allocated by the initiator; the
// last sub-transfer to finish invokes `done` and deletes the object.
struct TransferFromDeviceState {
  // Number of sub-transfers still outstanding.
  std::atomic<int64_t> remaining_transfers;
  TF_Status* overall_status =
      tpu::ExecutorApiFn()->TpuStatus_NewFn();  // OK or the first error
  // User callback invoked exactly once with the aggregated status.
  std::function<void(Status)> done;

  // Records the completion of one sub-transfer.  Takes ownership of
  // `status` — it is always freed here (possibly after being swapped into
  // overall_status to preserve the first error seen).
  void TransferFinished(TF_Status* status) {
    if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
        tpu::ExecutorApiFn()->TpuStatus_OkFn(overall_status)) {
      std::swap(overall_status, status);
    }
    tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);

    if (--remaining_transfers == 0) {
      // Last transfer: report the aggregate status, then self-destruct.
      // No member may be touched after `delete this`.
      done(StatusHelper::FromC(overall_status));
      tpu::ExecutorApiFn()->TpuStatus_FreeFn(overall_status);
      delete this;
    }
  }
};
187
// C-compatible completion callback handed to the C API; forwards to the
// heap-allocated TransferFromDeviceState (which may delete itself).
void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) {
  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
}
191
TransferLiteralFromDevice(stream_executor::Stream * stream,const xla::ShapedBuffer & device_buffer,xla::MutableBorrowingLiteral literal,std::function<void (Status)> done,const TransferMetadata * transfer_metadata)192 void TpuTransferManager::TransferLiteralFromDevice(
193 stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
194 xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
195 const TransferMetadata* transfer_metadata) {
196 TransferFromDeviceState* state = new TransferFromDeviceState;
197 state->remaining_transfers = 1;
198 state->done = done;
199 XLA_ShapedBuffer c_device_buffer;
200 ApiConverter::ToC(device_buffer, &c_device_buffer);
201 XLA_Literal c_literal;
202 ApiConverter::ToC(literal, &c_literal);
203
204 tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
205 manager_,
206 TpuPlatform::GetRegisteredPlatform()->LookupStream(
207 stream->implementation()),
208 &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
209 ApiConverter::Destroy(&c_device_buffer);
210 ApiConverter::Destroy(&c_literal);
211 }
212
GetByteSizeRequirement(const xla::Shape & shape) const213 int64_t TpuTransferManager::GetByteSizeRequirement(
214 const xla::Shape& shape) const {
215 XLA_Shape c_shape;
216 ApiConverter::ToC(shape, &c_shape);
217
218 int64_t size_in_bytes =
219 tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
220 manager_, &c_shape);
221
222 ApiConverter::Destroy(&c_shape);
223 return size_in_bytes;
224 }
225
// Asks the C API for the most compact device layout for `host_shape`.
// Returns the chosen shape, or the error reported by the C side.
StatusOr<xla::Shape> TpuTransferManager::ChooseCompactLayoutForShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  XLA_Shape c_output;
  StatusHelper status;
  tpu::ExecutorApiFn()->TpuTransferManager_ChooseCompactLayoutForShapeFn(
      manager_, &c_host_shape, &c_output, status.c_status);
  // TODO(skyewm): use a scoped version of XLA_Shape
  ApiConverter::Destroy(&c_host_shape);
  if (!status.status().ok()) {
    // c_output is destroyed on the error path too; presumably the C call
    // initializes it even on failure — confirm against the C API contract.
    ApiConverter::Destroy(&c_output);
    return status.status();
  }
  xla::Shape output = ApiConverter::FromC(&c_output);
  ApiConverter::Destroy(&c_output);
  return output;
}
244
// Returns true if `device_buffer` can be accessed immediately (without
// further synchronization), as reported by the C API.
bool TpuTransferManager::CanShapedBufferBeAccessedNow(
    stream_executor::StreamExecutor* executor,
    const xla::ShapedBuffer& device_buffer) const {
  auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);
  // Destroy the C wrapper on every return path.
  absl::Cleanup cleanup = [&c_device_buffer]() {
    ApiConverter::Destroy(&c_device_buffer);
  };
  return tpu::ExecutorApiFn()
      ->TpuTransferManager_CanShapedBufferBeAccessedNowFn(
          manager_, tpu_executor->se_executor(), &c_device_buffer);
}
258
CanBufferBeAccessedNow(se::StreamExecutor * executor,const se::DeviceMemoryBase & device_buffer) const259 bool TpuTransferManager::CanBufferBeAccessedNow(
260 se::StreamExecutor* executor,
261 const se::DeviceMemoryBase& device_buffer) const {
262 auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
263 SE_DeviceMemoryBase c_device_buffer{const_cast<void*>(device_buffer.opaque()),
264 device_buffer.size(),
265 device_buffer.payload()};
266 return tpu::ExecutorApiFn()->TpuTransferManager_CanBufferBeAccessedNowFn(
267 manager_, tpu_executor->se_executor(), &c_device_buffer);
268 }
269
WriteSingleTupleIndexTable(stream_executor::Stream * stream,absl::Span<const stream_executor::DeviceMemoryBase> elements,const xla::Shape & shape,stream_executor::DeviceMemoryBase * region)270 Status TpuTransferManager::WriteSingleTupleIndexTable(
271 stream_executor::Stream* stream,
272 absl::Span<const stream_executor::DeviceMemoryBase> elements,
273 const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
274 CHECK_GT(elements.size(), 0);
275 SE_DeviceMemoryBase* elements_bases =
276 new SE_DeviceMemoryBase[elements.size()];
277 for (int i = 0; i < elements.size(); i++) {
278 elements_bases[i] =
279 SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
280 elements[i].size(), elements[i].payload()};
281 }
282 XLA_Shape c_shape;
283 ApiConverter::ToC(shape, &c_shape);
284 SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
285 region->payload()};
286 StatusHelper status;
287
288 tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
289 manager_,
290 TpuPlatform::GetRegisteredPlatform()->LookupStream(
291 stream->implementation()),
292 elements_bases, elements.size(), &c_shape, ®ion_base, status.c_status);
293
294 delete[] elements_bases;
295 ApiConverter::Destroy(&c_shape);
296 return status.status();
297 }
298
LinearizeToBuffers(const xla::LiteralSlice & literal,std::deque<tensorflow::tpu::NoncopyableBuffer> * buffers)299 Status TpuTransferManager::LinearizeToBuffers(
300 const xla::LiteralSlice& literal,
301 std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) {
302 XLA_Literal c_literal;
303 ApiConverter::ToC(literal, &c_literal);
304
305 char** buffers_array;
306 int64_t* buffers_size;
307 int64_t buffers_array_size;
308 StatusHelper status;
309
310 tpu::ExecutorApiFn()->TpuTransferManager_LinearizeToBuffersFn(
311 manager_, &c_literal, &buffers_array, &buffers_size, &buffers_array_size,
312 status.c_status);
313
314 for (int64_t i = 0; i < buffers_array_size; ++i) {
315 tpu::NoncopyableBuffer buf(buffers_size[i]);
316 memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
317 buffers_size[i]);
318 buffers->push_back(std::move(buf));
319 }
320
321 tpu::ExecutorApiFn()->TpuTransferManager_FreeBuffersFn(
322 buffers_array, buffers_size, buffers_array_size);
323
324 ApiConverter::Destroy(&c_literal);
325 return status.status();
326 }
327
// Reads the actual (dynamic) dimensions of `device_buffer` back from the
// device and replaces `*device_shape` with the updated shape on success.
Status TpuTransferManager::ReadDynamicShapes(se::Stream* stream,
                                             xla::ShapedBuffer* device_buffer,
                                             xla::Shape* device_shape) {
  XLA_ShapedBuffer c_device_buffer;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(*device_buffer, &c_device_buffer);
  ApiConverter::ToC(*device_shape, &c_device_shape);
  XLA_Shape c_updated_shape;
  StatusHelper status;
  // NOTE: unlike the other wrappers this C entry point takes no manager_
  // argument, and c_device_shape is passed by value.
  ExecutorApiFn()->TpuTransferManager_ReadDynamicShapesFn(
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_device_buffer, c_device_shape, &c_updated_shape, status.c_status);
  ApiConverter::Destroy(&c_device_buffer);
  ApiConverter::Destroy(&c_device_shape);
  if (!status.ok()) {
    // On failure, *device_shape is left untouched.
    return status.status();
  }
  *device_shape = ApiConverter::FromC(&c_updated_shape);
  ApiConverter::Destroy(&c_updated_shape);
  return OkStatus();
}
350
351 } // namespace tpu
352 } // namespace tensorflow
353