/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/infeed_ops.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef tensorflow::tpu::NoncopyableBuffer LinearizerBuffer;
typedef std::deque<LinearizerBuffer> LinearizerBufferList;

// For the given shape, chooses a layout for infeed on TPU. The returned shape
// has the same dimensions as the original shape, and only the layout is
// changed.
xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
  XLA_Shape c_shape;
  XLA_Shape c_infeed_shape;

  ApiConverter::ToC(shape, &c_shape);

  tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
                                                             &c_infeed_shape);
  xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
  ApiConverter::Destroy(&c_shape);
  ApiConverter::Destroy(&c_infeed_shape);
  return infeed_shape;
}

// Transposes the given tensor using the tensorflow C++ transpose
// implementation to obtain an XLA literal for the host tensor laid out as the
// given layout. The returned tensor is normalized to the dim0major layout --
// F32[10,20,30]{2,0,1} is returned as F32[20,10,30]{2,1,0}.
xla::StatusOr<Tensor> TransposeTensor(OpKernelContext* ctx,
                                      const Tensor& input_tensor,
                                      const xla::Shape& xla_shape) {
  profiler::TraceMe trace_me("TransposeTensor", /*level=*/2);
  const int64_t rank = xla_shape.rank();
  std::vector<int32> permutation(rank);
  std::vector<int64_t> transposed_shapes(rank);
  for (int64_t i = 0; i < rank; ++i) {
    permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i);
    transposed_shapes[i] = xla_shape.dimensions(permutation[i]);
  }

  Tensor transposed_tensor;

  // If this is a trivial transpose (i.e., bitcast), just create an aliased
  // tensor with the transposed shape.
  if (xla::LayoutUtil::IsMonotonicWithDim0Major(
          xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) {
    TensorShape shape;
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(transposed_shapes, &shape));
    TF_RETURN_IF_ERROR(transposed_tensor.BitcastFrom(
        input_tensor, input_tensor.dtype(), shape));
    return transposed_tensor;
  }

  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_temp(input_tensor.dtype(),
                                        TensorShape(transposed_shapes),
                                        &transposed_tensor, alloc_attr));
  // Eigen Transpose fails with SIGFPE if there is a dimension of size 0.
  if (input_tensor.NumElements() > 0) {
    TF_RETURN_IF_ERROR(DoTranspose<CPUDevice>(ctx->eigen_device<CPUDevice>(),
                                              input_tensor, permutation,
                                              &transposed_tensor));
  }
  return transposed_tensor;
}

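// Reads the layout-override attribute `attrn_name` into `minor_to_major`, if
// the op defines it. Returns true iff a non-empty override was supplied.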
xla::StatusOr<bool> GetLayoutOverride(OpKernelConstruction* ctx,
                                      const char* attrn_name,
                                      std::vector<int64_t>* minor_to_major) {
  if (!ctx->HasAttr(attrn_name)) {
    return false;
  }
  TF_RETURN_IF_ERROR(ctx->GetAttr(attrn_name, minor_to_major));
  return !minor_to_major->empty();
}

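// Computes the infeed shape for `input_shape`: a layout override supplied via
// the `attrn_name` attribute takes precedence; otherwise each (sub)shape is
// given the default TPU infeed layout.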
Status GetInfeedShapeWithLayout(OpKernelConstruction* ctx,
                                const char* attrn_name,
                                const xla::Shape& input_shape,
                                xla::Shape* output_shape) {
  std::vector<int64_t> minor_to_major;
  TF_ASSIGN_OR_RETURN(bool has_override,
                      GetLayoutOverride(ctx, attrn_name, &minor_to_major));
  if (!has_override) {
    *output_shape = input_shape;
    if (output_shape->IsTuple()) {
      int64_t tuple_elements = xla::ShapeUtil::TupleElementCount(*output_shape);
      for (int64_t i = 0; i < tuple_elements; ++i) {
        xla::Shape* sub_shape =
            xla::ShapeUtil::GetMutableSubshape(output_shape, {i});
        *sub_shape->mutable_layout() = GetTPUInfeedLayout(*sub_shape).layout();
      }
    } else {
      *output_shape->mutable_layout() =
          GetTPUInfeedLayout(*output_shape).layout();
    }
    return OkStatus();
  }

  auto layout_func = [](const xla::Shape& shape) -> xla::Layout {
    return GetTPUInfeedLayout(shape).layout();
  };
  return GetShapeWithLayout(input_shape, minor_to_major, layout_func,
                            output_shape);
}

// LinearizedBuffersWrapper is an opaque C++ data structure for the outputs of
// PrelinearizeOp and PrelinearizeTupleOp. It holds the resultant linearized
// buffers and references to input tensors whose underlying storage is shared
// with the linearized buffers.
// NOTE: This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `LinearizerBufferList` (aka `std::deque<LinearizerBuffer>`) object, so the
// `Encode()` and `Decode()` methods are not implemented.
struct LinearizedBuffersWrapper {
  explicit LinearizedBuffersWrapper() {}
  explicit LinearizedBuffersWrapper(LinearizerBufferList bufs,
                                    std::vector<tensorflow::Tensor> ts)
      : buffers(std::move(bufs)), tensors(std::move(ts)) {}
  LinearizedBuffersWrapper(const LinearizedBuffersWrapper& wrapper) {
    // tensorflow::Variant requires this copy constructor to compile.
    LOG(FATAL) << "LinearizedBuffersWrapper should not copy.";
  }
  LinearizedBuffersWrapper& operator=(const LinearizedBuffersWrapper& wrapper) =
      delete;
  LinearizedBuffersWrapper(LinearizedBuffersWrapper&&) = default;
  LinearizedBuffersWrapper& operator=(LinearizedBuffersWrapper&&) = default;
  ~LinearizedBuffersWrapper() = default;

  // These functions are tensorflow::Variant requirements.
  string TypeName() const { return "(anonymous)::LinearizedBuffersWrapper"; }
  void Encode(tensorflow::VariantTensorData* data) const {
    LOG(ERROR) << "Encode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
  }
  bool Decode(const tensorflow::VariantTensorData& data) {
    LOG(ERROR) << "Decode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
    return false;
  }

  LinearizerBufferList buffers;
  // Save references to tensors whose underlying storage is shared with
  // LiteralLinearizer::Buffer in `buffers`.
  std::vector<tensorflow::Tensor> tensors;
};

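// Transposes `input_tensor` into the layout required by `shape` if necessary,
// then linearizes it into device-format buffers. When a resulting buffer
// aliases the tensor's storage, the tensor is appended to
// `saved_input_tensors` to keep that storage alive.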
Status AutoTransposeAndLinearize(OpKernelContext* ctx,
                                 const Tensor& input_tensor,
                                 const xla::Shape& shape,
                                 LinearizerBufferList* linearized_buffers,
                                 std::vector<Tensor>* saved_input_tensors) {
  const Tensor* tensor = &input_tensor;
  bool has_transposed = false;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, shape));
    tensor = &transposed_tensor;
    has_transposed = true;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

  TF_RETURN_IF_ERROR(
      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()
          ->LinearizeToBuffers(literal, linearized_buffers));

  // The input tensor is ref-counted. Save a handle on the input tensor if its
  // underlying storage is shared with the linearized buffers, to prevent the
  // input tensor from getting freed.
  for (const auto& buffer : *linearized_buffers) {
    if (!buffer.owns_data() && !has_transposed) {
      // `buffer` was created via the zero-copy fast path from the
      // un-transposed input tensor, so its underlying data is shared with the
      // input tensor. Save a handle to the input tensor to increment its
      // ref-count and avoid it getting deallocated after PrelinearizeTupleOp
      // completes.
      saved_input_tensors->push_back(*tensor);
      // A literal can be linearized into anywhere from zero to two buffers. If
      // any of the linearized buffers shares storage with the input tensor, we
      // save exactly one handle on the input tensor.
      break;
    }
  }
  return OkStatus();
}

// PrelinearizeOp is used to linearize one tensor to the device format.
class PrelinearizeOp : public OpKernel {
 public:
  explicit PrelinearizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    xla::Shape shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
    OP_REQUIRES_OK(ctx,
                   GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_tensor = ctx->input(0);
    // Validate input.
    OP_REQUIRES(
        ctx, input_tensor.dtype() == dtype_,
        errors::InvalidArgument("Prelinearize dtype mismatch; expected ",
                                DataType_Name(dtype_), ", got ",
                                DataType_Name(input_tensor.dtype())));
    OP_REQUIRES(
        ctx, input_tensor.shape() == shape_,
        errors::InvalidArgument("Prelinearize shape mismatch; expected ",
                                shape_.DebugString(), ", got ",
                                input_tensor.shape().DebugString()));

    // Auto-transpose and prelinearize.
    LinearizerBufferList linearized_buffers;
    std::vector<Tensor> saved_input_tensors;
    auto status =
        AutoTransposeAndLinearize(ctx, input_tensor, xla_shape_,
                                  &linearized_buffers, &saved_input_tensors);
    OP_REQUIRES_OK(ctx, status);

    // Write to output.
    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(linearized_buffers), std::move(saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // PrelinearizeOp is neither copyable nor movable.
  PrelinearizeOp(const PrelinearizeOp&) = delete;
  PrelinearizeOp& operator=(const PrelinearizeOp&) = delete;
};

// PrelinearizeTupleOp is used to linearize multiple tensors to the device
// format.
class PrelinearizeTupleOp : public OpKernel {
 public:
  explicit PrelinearizeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES(
        ctx, shapes_.size() == dtypes_.size(),
        errors::InvalidArgument(
            "shapes and dtypes must be the same length. shapes length = ",
            shapes_.size(), ", dtypes length = ", dtypes_.size()));

    std::vector<xla::Shape> xla_shapes;
    for (int i = 0; i < shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(ctx,
                     TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
      xla_shapes.push_back(xla_shape);
    }
    OP_REQUIRES_OK(
        ctx, GetInfeedShapeWithLayout(
                 ctx, "layouts", xla::ShapeUtil::MakeTupleShape(xla_shapes),
                 &tuple_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to PrelinearizeTuple."));

    LinearizerBufferList all_linearized_buffers;
    std::vector<Tensor> all_saved_input_tensors;
    for (int i = 0; i < values.size(); i++) {
      // Validate input.
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple dtype mismatch at tuple element ", i,
                      "; expected ", DataType_Name(dtypes_[i]), ", got ",
                      DataType_Name(input_tensor.dtype())));
      OP_REQUIRES(ctx, input_tensor.shape() == shapes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple shape mismatch at tuple element ", i,
                      "; expected ", shapes_[i].DebugString(), ", got ",
                      input_tensor.shape().DebugString()));

      // Auto-transpose and prelinearize.
      LinearizerBufferList linearized_buffers;
      std::vector<Tensor> saved_input_tensors;
      auto status = AutoTransposeAndLinearize(
          ctx, input_tensor, tuple_shape_.tuple_shapes(i), &linearized_buffers,
          &saved_input_tensors);
      OP_REQUIRES_OK(ctx, status);
      all_linearized_buffers.insert(
          all_linearized_buffers.end(),
          std::make_move_iterator(linearized_buffers.begin()),
          std::make_move_iterator(linearized_buffers.end()));
      all_saved_input_tensors.insert(
          all_saved_input_tensors.end(),
          std::make_move_iterator(saved_input_tensors.begin()),
          std::make_move_iterator(saved_input_tensors.end()));
    }

    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(all_linearized_buffers), std::move(all_saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  xla::Shape tuple_shape_;

  // PrelinearizeTupleOp is neither copyable nor movable.
  PrelinearizeTupleOp(const PrelinearizeTupleOp&) = delete;
  PrelinearizeTupleOp& operator=(const PrelinearizeTupleOp&) = delete;
};

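// The kernels below bind the generic infeed-enqueue ops to the
// StreamExecutor-based TPU transfer implementation.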
class StreamExecutorInfeedEnqueueOp : public TpuInfeedEnqueueOp {
 public:
  explicit StreamExecutorInfeedEnqueueOp(OpKernelConstruction* ctx)
      : TpuInfeedEnqueueOp(ctx,
                           absl::make_unique<StreamExecutorTransferOpImpl>()) {}

 private:
  StreamExecutorInfeedEnqueueOp(const StreamExecutorInfeedEnqueueOp&) = delete;
  StreamExecutorInfeedEnqueueOp& operator=(
      const StreamExecutorInfeedEnqueueOp&) = delete;
};

class StreamExecutorInfeedEnqueueTupleOp : public TpuInfeedEnqueueTupleOp {
 public:
  explicit StreamExecutorInfeedEnqueueTupleOp(OpKernelConstruction* ctx)
      : TpuInfeedEnqueueTupleOp(
            ctx, absl::make_unique<StreamExecutorTransferOpImpl>()) {}

 private:
  StreamExecutorInfeedEnqueueTupleOp(
      const StreamExecutorInfeedEnqueueTupleOp&) = delete;
  StreamExecutorInfeedEnqueueTupleOp& operator=(
      const StreamExecutorInfeedEnqueueTupleOp&) = delete;
};

class StreamExecutorInfeedEnqueuePrelinearizedBufferOp
    : public InfeedEnqueuePrelinearizedBufferOp {
 public:
  explicit StreamExecutorInfeedEnqueuePrelinearizedBufferOp(
      OpKernelConstruction* ctx)
      : InfeedEnqueuePrelinearizedBufferOp(
            ctx, absl::make_unique<StreamExecutorTransferOpImpl>()) {}

 private:
  // StreamExecutorInfeedEnqueuePrelinearizedBufferOp is neither copyable nor
  // movable.
  StreamExecutorInfeedEnqueuePrelinearizedBufferOp(
      const StreamExecutorInfeedEnqueuePrelinearizedBufferOp&) = delete;
  StreamExecutorInfeedEnqueuePrelinearizedBufferOp& operator=(
      const StreamExecutorInfeedEnqueuePrelinearizedBufferOp&) = delete;
};
}  // anonymous namespace

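// Parses the shape/dtype/layout attributes and precomputes the XLA infeed
// shape consumed by DoWork().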
TpuInfeedEnqueueOp::TpuInfeedEnqueueOp(
    OpKernelConstruction* ctx,
    std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8,
                               std::move(transfer_op)) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  xla::Shape shape;
  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
  OP_REQUIRES_OK(ctx,
                 GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
}

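// Validates the input against the declared dtype/shape, transposes it into the
// infeed layout if necessary, and transfers it to the device's infeed queue.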
Status TpuInfeedEnqueueOp::DoWork(OpKernelContext* ctx, int device_ordinal) {
  VLOG(1) << "TpuInfeedEnqueueOp::DoWork. iter_id=" << ctx->frame_iter().iter_id
          << " device_ordinal=" << device_ordinal;
  const Tensor& input_tensor = ctx->input(0);

  // Validate runtime shape and fail if it doesn't match the contract.
  if (input_tensor.dtype() != dtype_) {
    return errors::InvalidArgument("Infeed dtype mismatch.");
  }
  if (input_tensor.shape() != shape_) {
    return errors::InvalidArgument("Infeed shape mismatch; expected ",
                                   shape_.DebugString(), ", got ",
                                   input_tensor.shape().DebugString());
  }

  const Tensor* tensor = &input_tensor;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(xla_shape_.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, xla_shape_));
    tensor = &transposed_tensor;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_op_->TransferLiteralToInfeed(device_ordinal, literal));
  VLOG(1) << "TpuInfeedEnqueueOp completes. iter_id="
          << ctx->frame_iter().iter_id << " device_ordinal=" << device_ordinal;
  return OkStatus();
}

TpuInfeedEnqueueTupleOp::TpuInfeedEnqueueTupleOp(
    OpKernelConstruction* ctx,
    std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8,
                               std::move(transfer_op)) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(
      ctx, shapes_.size() == dtypes_.size(),
      errors::InvalidArgument("shapes and dtypes must be the same length."));

  std::vector<xla::Shape> xla_shapes;
  for (int i = 0; i < shapes_.size(); i++) {
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
    xla_shapes.push_back(xla_shape);
  }
  OP_REQUIRES_OK(
      ctx, GetInfeedShapeWithLayout(ctx, "layouts",
                                    xla::ShapeUtil::MakeTupleShape(xla_shapes),
                                    &tuple_shape_));
}

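// Validates each tuple element, transposes elements whose layout is not
// dim0major, and transfers the resulting tuple literal to the device's infeed.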
Status TpuInfeedEnqueueTupleOp::DoWork(OpKernelContext* ctx,
                                       int device_ordinal) {
  VLOG(1) << "TpuInfeedEnqueueTupleOp::DoWork. iter_id="
          << ctx->frame_iter().iter_id << " device_ordinal=" << device_ordinal;
  OpInputList values;
  TF_RETURN_IF_ERROR(ctx->input_list("inputs", &values));
  if (values.size() != shapes_.size()) {
    return errors::InvalidArgument(
        "Wrong number of inputs to InfeedEnqueueTuple.");
  }

  for (const auto& shape : shapes_) {
    VLOG(2) << "TransferLiteralToInfeed " << shape.DebugString();
  }

  std::vector<Tensor> maybe_transposed_tensors;
  maybe_transposed_tensors.reserve(values.size());
  for (int i = 0; i < values.size(); i++) {
    // Validate runtime shapes and fail if they don't match the contract.
    const Tensor* tensor = &values[i];
    if (tensor->shape() != shapes_[i]) {
      return errors::InvalidArgument("Infeed shape mismatch for tuple element ",
                                     i, "; expected ", shapes_[i].DebugString(),
                                     ", got ", tensor->shape().DebugString());
    }
    if (!xla::LayoutUtil::IsMonotonicWithDim0Major(
            tuple_shape_.tuple_shapes(i).layout())) {
      // If the given layout is not in dim0major layout, transpose the given
      // tensor.
      TF_ASSIGN_OR_RETURN(
          Tensor transposed_tensor,
          TransposeTensor(ctx, *tensor, tuple_shape_.tuple_shapes(i)));
      maybe_transposed_tensors.emplace_back(transposed_tensor);
    } else {
      maybe_transposed_tensors.emplace_back(*tensor);
    }
  }

  xla::BorrowingLiteral tuple;
  TF_RETURN_IF_ERROR(
      HostTensorsToBorrowingLiteralTuple(maybe_transposed_tensors, &tuple));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_op_->TransferLiteralToInfeed(device_ordinal, tuple));

  VLOG(1) << "TpuInfeedEnqueueTupleOp completes. iter_id="
          << ctx->frame_iter().iter_id << " device_ordinal=" << device_ordinal;

  return OkStatus();
}

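// Enqueues buffers that were already linearized by Prelinearize(Tuple)Op,
// unwrapping them from the DT_VARIANT LinearizedBuffersWrapper input.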
InfeedEnqueuePrelinearizedBufferOp::InfeedEnqueuePrelinearizedBufferOp(
    OpKernelConstruction* ctx,
    std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernel(ctx, "prelinearized_buffers_to_infeed", 8,
                               std::move(transfer_op)) {}

Status InfeedEnqueuePrelinearizedBufferOp::DoWork(OpKernelContext* ctx,
                                                  int device_ordinal) {
  const Tensor& input_tensor = ctx->input(0);
  const LinearizedBuffersWrapper* wrapper =
      input_tensor.scalar<tensorflow::Variant>()()
          .get<LinearizedBuffersWrapper>();
  TF_RETURN_IF_ERROR(
      transfer_op_->TransferBuffersToInfeed(device_ordinal, wrapper->buffers));

  return OkStatus();
}

// These ops execute on either the TPU device or the CPU device. When running
// on CPU, they must specify a non-negative value for device_ordinal to
// indicate which TPU to send infeed to.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueue").Device(DEVICE_TPU_NODE).HostMemory("input"),
    StreamExecutorInfeedEnqueueOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueue").Device(DEVICE_CPU),
                        StreamExecutorInfeedEnqueueOp);

REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueueTuple").Device(DEVICE_TPU_NODE).HostMemory("inputs"),
    StreamExecutorInfeedEnqueueTupleOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueueTuple").Device(DEVICE_CPU),
                        StreamExecutorInfeedEnqueueTupleOp);

// Prelinearize ops run on CPU as part of the tf.data input pipeline.
REGISTER_KERNEL_BUILDER(Name("Prelinearize").Device(DEVICE_CPU),
                        PrelinearizeOp);
REGISTER_KERNEL_BUILDER(Name("PrelinearizeTuple").Device(DEVICE_CPU),
                        PrelinearizeTupleOp);

// The InfeedEnqueuePrelinearizedBuffer op runs on CPU and takes a
// device_ordinal to select the right device to infeed.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueuePrelinearizedBuffer").Device(DEVICE_CPU),
    StreamExecutorInfeedEnqueuePrelinearizedBufferOp);

}  // namespace tensorflow