/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/infeed_ops.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/transfer_ops.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef tensorflow::tpu::NoncopyableBuffer LinearizerBuffer;
typedef std::deque<LinearizerBuffer> LinearizerBufferList;
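// A LinearizerBuffer may own its bytes or alias the storage of the tensor it
// was linearized from (see the owns_data() checks below).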

// For the given shape, chooses a layout for infeed on TPU. The returned shape
// has the same dimensions as the original shape, and only the layout is
// changed.
xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
  XLA_Shape c_shape;
  XLA_Shape c_infeed_shape;

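  // Round-trip through the TPU C API: convert the shape to its C
  // representation, ask the TPU runtime for its preferred infeed layout, and
  // convert the result back to an xla::Shape.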
  ApiConverter::ToC(shape, &c_shape);

  tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
                                                             &c_infeed_shape);
  xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
  ApiConverter::Destroy(&c_shape);
  ApiConverter::Destroy(&c_infeed_shape);
  return infeed_shape;
}

// Transposes the given tensor using the TensorFlow C++ transpose
// implementation to obtain an XLA literal for the host tensor laid out as the
// given layout. The returned tensor is normalized to the dim0major layout --
// F32[10,20,30]{2,0,1} is returned as F32[20,10,30]{2,1,0}.
xla::StatusOr<Tensor> TransposeTensor(OpKernelContext* ctx,
                                      const Tensor& input_tensor,
                                      const xla::Shape& xla_shape) {
  profiler::TraceMe trace_me("TransposeTensor", /*level=*/2);
  const int64_t rank = xla_shape.rank();
  std::vector<int32> permutation(rank);
  std::vector<int64_t> transposed_shapes(rank);
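  // Reversing minor_to_major yields the permutation that makes the layout
  // dim0-major: dimension 0 of the transposed tensor is the most-major
  // dimension of `xla_shape`.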
  for (int64_t i = 0; i < rank; ++i) {
    permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i);
    transposed_shapes[i] = xla_shape.dimensions(permutation[i]);
  }

  Tensor transposed_tensor;

  // If this is a trivial transpose (i.e., bitcast), just create an aliased
  // tensor with the transposed shape.
  if (xla::LayoutUtil::IsMonotonicWithDim0Major(
          xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) {
    TensorShape shape;
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(transposed_shapes, &shape));
    TF_RETURN_IF_ERROR(transposed_tensor.BitcastFrom(
        input_tensor, input_tensor.dtype(), shape));
    return transposed_tensor;
  }

  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_temp(input_tensor.dtype(),
                                        TensorShape(transposed_shapes),
                                        &transposed_tensor, alloc_attr));
  // Eigen Transpose fails with SIGFPE if there is a dimension of size 0.
  if (input_tensor.NumElements() > 0) {
    TF_RETURN_IF_ERROR(DoTranspose<CPUDevice>(ctx->eigen_device<CPUDevice>(),
                                              input_tensor, permutation,
                                              &transposed_tensor));
  }
  return transposed_tensor;
}

xla::StatusOr<bool> GetLayoutOverride(OpKernelConstruction* ctx,
                                      const char* attrn_name,
                                      std::vector<int64_t>* minor_to_major) {
  if (!ctx->HasAttr(attrn_name)) {
    return false;
  }
  TF_RETURN_IF_ERROR(ctx->GetAttr(attrn_name, minor_to_major));
  return !minor_to_major->empty();
}

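// Sets `output_shape` to `input_shape` with a layout suitable for TPU infeed:
// the layout override from attribute `attrn_name` if one is present, otherwise
// the runtime's preferred infeed layout for each (sub)shape.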
Status GetInfeedShapeWithLayout(OpKernelConstruction* ctx,
                                const char* attrn_name,
                                const xla::Shape& input_shape,
                                xla::Shape* output_shape) {
  std::vector<int64_t> minor_to_major;
  TF_ASSIGN_OR_RETURN(bool has_override,
                      GetLayoutOverride(ctx, attrn_name, &minor_to_major));
  if (!has_override) {
    *output_shape = input_shape;
    if (output_shape->IsTuple()) {
      int64_t tuple_elements = xla::ShapeUtil::TupleElementCount(*output_shape);
      for (int64_t i = 0; i < tuple_elements; ++i) {
        xla::Shape* sub_shape =
            xla::ShapeUtil::GetMutableSubshape(output_shape, {i});
        *sub_shape->mutable_layout() = GetTPUInfeedLayout(*sub_shape).layout();
      }
    } else {
      *output_shape->mutable_layout() =
          GetTPUInfeedLayout(*output_shape).layout();
    }
    return OkStatus();
  }

  auto layout_func = [](const xla::Shape& shape) -> xla::Layout {
    return GetTPUInfeedLayout(shape).layout();
  };
  return GetShapeWithLayout(input_shape, minor_to_major, layout_func,
                            output_shape);
}

// LinearizedBuffersWrapper is an opaque C++ data structure for the outputs of
// PrelinearizeOp and PrelinearizeTupleOp. It holds the resultant linearized
// buffers and references to input tensors whose underlying storage is shared
// with the linearized buffers.
// NOTE: This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `LinearizerBufferList` (aka `std::deque<LinearizerBuffer>`)
// object, so the `Encode()` and `Decode()` methods are not implemented.
struct LinearizedBuffersWrapper {
  explicit LinearizedBuffersWrapper() {}
  explicit LinearizedBuffersWrapper(LinearizerBufferList bufs,
                                    std::vector<tensorflow::Tensor> ts)
      : buffers(std::move(bufs)), tensors(std::move(ts)) {}
  LinearizedBuffersWrapper(const LinearizedBuffersWrapper& wrapper) {
    // tensorflow::Variant requires this copy constructor to compile.
    LOG(FATAL) << "LinearizedBuffersWrapper should not copy.";
  }
  LinearizedBuffersWrapper& operator=(const LinearizedBuffersWrapper& wrapper) =
      delete;
  LinearizedBuffersWrapper(LinearizedBuffersWrapper&&) = default;
  LinearizedBuffersWrapper& operator=(LinearizedBuffersWrapper&&) = default;
  ~LinearizedBuffersWrapper() = default;

  // These functions are tensorflow::Variant requirements.
  string TypeName() const { return "(anonymous)::LinearizedBuffersWrapper"; }
  void Encode(tensorflow::VariantTensorData* data) const {
    LOG(ERROR) << "Encode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
  }
  bool Decode(const tensorflow::VariantTensorData& data) {
    LOG(ERROR) << "Decode() is not implemented for LinearizedBuffersWrapper "
                  "objects.";
    return false;
  }

  LinearizerBufferList buffers;
  // Save references on tensors whose underlying storage is shared with
  // LiteralLinearizer::Buffer in `buffers`.
  std::vector<tensorflow::Tensor> tensors;
};

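// Transposes `input_tensor` into the layout given by `shape` if necessary,
// then linearizes it into `linearized_buffers` in the TPU infeed format.
// Tensors whose storage is aliased by the linearized buffers are appended to
// `saved_input_tensors` to keep that storage alive.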
Status AutoTransposeAndLinearize(OpKernelContext* ctx,
                                 const Tensor& input_tensor,
                                 const xla::Shape& shape,
                                 LinearizerBufferList* linearized_buffers,
                                 std::vector<Tensor>* saved_input_tensors) {
  const Tensor* tensor = &input_tensor;
  bool has_transposed = false;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, shape));
    tensor = &transposed_tensor;
    has_transposed = true;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

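  // Linearize the literal into device-format buffers via the registered TPU
  // transfer manager. Buffers may alias the literal's storage rather than own
  // a copy (the zero-copy fast path checked below).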
  TF_RETURN_IF_ERROR(
      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()
          ->LinearizeToBuffers(literal, linearized_buffers));

  // The input tensor is ref-counted. Save a handle on the input tensor if
  // its underlying storage is shared with the linearized buffers, to prevent
  // the input tensor from getting freed.
  for (const auto& buffer : *linearized_buffers) {
    if (!buffer.owns_data() && !has_transposed) {
      // `buffer` was created via the zero-copy fast path from the
      // un-transposed input tensor, so its underlying data is shared with the
      // input tensor. Save a handle to the input tensor to increment its
      // ref-count and keep it from being deallocated after the op completes.
      saved_input_tensors->push_back(*tensor);
      // A literal can be linearized into anywhere from zero to two buffers.
      // If any of the linearized buffers shares storage with the input
      // tensor, we save exactly one handle on the input tensor.
      break;
    }
  }
  return OkStatus();
}

// PrelinearizeOp is used to linearize one tensor to the device format.
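// Its output is a scalar DT_VARIANT tensor holding a LinearizedBuffersWrapper
// with the device-format buffers.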
class PrelinearizeOp : public OpKernel {
 public:
  explicit PrelinearizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    xla::Shape shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
    OP_REQUIRES_OK(ctx,
                   GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_tensor = ctx->input(0);
    // Validate input.
    OP_REQUIRES(
        ctx, input_tensor.dtype() == dtype_,
        errors::InvalidArgument("Prelinearize dtype mismatch; expected ",
                                DataType_Name(dtype_), ", got ",
                                DataType_Name(input_tensor.dtype())));
    OP_REQUIRES(
        ctx, input_tensor.shape() == shape_,
        errors::InvalidArgument("Prelinearize shape mismatch; expected ",
                                shape_.DebugString(), ", got ",
                                input_tensor.shape().DebugString()));

    // Auto-transpose and prelinearize.
    LinearizerBufferList linearized_buffers;
    std::vector<Tensor> saved_input_tensors;
    auto status =
        AutoTransposeAndLinearize(ctx, input_tensor, xla_shape_,
                                  &linearized_buffers, &saved_input_tensors);
    OP_REQUIRES_OK(ctx, status);

    // Write to output.
    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(linearized_buffers), std::move(saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  TensorShape shape_;
  DataType dtype_;
  xla::Shape xla_shape_;

  // PrelinearizeOp is neither copyable nor movable.
  PrelinearizeOp(const PrelinearizeOp&) = delete;
  PrelinearizeOp& operator=(const PrelinearizeOp&) = delete;
};

// PrelinearizeTupleOp is used to linearize multiple tensors to the device
// format.
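// Like PrelinearizeOp, it emits a scalar DT_VARIANT holding a single
// LinearizedBuffersWrapper that covers all tuple elements.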
class PrelinearizeTupleOp : public OpKernel {
 public:
  explicit PrelinearizeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES(
        ctx, shapes_.size() == dtypes_.size(),
        errors::InvalidArgument(
            "shapes and dtypes must be the same length. shapes length = ",
            shapes_.size(), ", dtypes length = ", dtypes_.size()));

    std::vector<xla::Shape> xla_shapes;
    for (int i = 0; i < shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(ctx,
                     TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
      xla_shapes.push_back(xla_shape);
    }
    OP_REQUIRES_OK(
        ctx, GetInfeedShapeWithLayout(
                 ctx, "layouts", xla::ShapeUtil::MakeTupleShape(xla_shapes),
                 &tuple_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to PrelinearizeTuple."));

    LinearizerBufferList all_linearized_buffers;
    std::vector<Tensor> all_saved_input_tensors;
    for (int i = 0; i < values.size(); i++) {
      // Validate input.
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple dtype mismatch at tuple element ", i,
                      "; expected ", DataType_Name(dtypes_[i]), ", got ",
                      DataType_Name(input_tensor.dtype())));
      OP_REQUIRES(ctx, input_tensor.shape() == shapes_[i],
                  errors::InvalidArgument(
                      "PrelinearizeTuple shape mismatch at tuple element ", i,
                      "; expected ", shapes_[i].DebugString(), ", got ",
                      input_tensor.shape().DebugString()));

      // Auto-transpose and prelinearize.
      LinearizerBufferList linearized_buffers;
      std::vector<Tensor> saved_input_tensors;
      auto status = AutoTransposeAndLinearize(
          ctx, input_tensor, tuple_shape_.tuple_shapes(i), &linearized_buffers,
          &saved_input_tensors);
      OP_REQUIRES_OK(ctx, status);
      all_linearized_buffers.insert(
          all_linearized_buffers.end(),
          std::make_move_iterator(linearized_buffers.begin()),
          std::make_move_iterator(linearized_buffers.end()));
      all_saved_input_tensors.insert(
          all_saved_input_tensors.end(),
          std::make_move_iterator(saved_input_tensors.begin()),
          std::make_move_iterator(saved_input_tensors.end()));
    }

    tensorflow::Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
        std::move(all_linearized_buffers), std::move(all_saved_input_tensors)};
  }

  bool IsExpensive() override { return true; }

 private:
  std::vector<TensorShape> shapes_;
  DataTypeVector dtypes_;
  xla::Shape tuple_shape_;

  // PrelinearizeTupleOp is neither copyable nor movable.
  PrelinearizeTupleOp(const PrelinearizeTupleOp&) = delete;
  PrelinearizeTupleOp& operator=(const PrelinearizeTupleOp&) = delete;
};

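// Thin wrappers that bind the generic enqueue kernels defined below to the
// StreamExecutor-based TPU transfer implementation.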
class StreamExecutorInfeedEnqueueOp : public TpuInfeedEnqueueOp {
 public:
  explicit StreamExecutorInfeedEnqueueOp(OpKernelConstruction* ctx)
      : TpuInfeedEnqueueOp(ctx,
                           absl::make_unique<StreamExecutorTransferOpImpl>()) {}

 private:
  StreamExecutorInfeedEnqueueOp(const StreamExecutorInfeedEnqueueOp&) = delete;
  StreamExecutorInfeedEnqueueOp& operator=(
      const StreamExecutorInfeedEnqueueOp&) = delete;
};

class StreamExecutorInfeedEnqueueTupleOp : public TpuInfeedEnqueueTupleOp {
 public:
  explicit StreamExecutorInfeedEnqueueTupleOp(OpKernelConstruction* ctx)
      : TpuInfeedEnqueueTupleOp(
            ctx, absl::make_unique<StreamExecutorTransferOpImpl>()) {}

 private:
  StreamExecutorInfeedEnqueueTupleOp(
      const StreamExecutorInfeedEnqueueTupleOp&) = delete;
  StreamExecutorInfeedEnqueueTupleOp& operator=(
      const StreamExecutorInfeedEnqueueTupleOp&) = delete;
};

class StreamExecutorInfeedEnqueuePrelinearizedBufferOp
    : public InfeedEnqueuePrelinearizedBufferOp {
 public:
  explicit StreamExecutorInfeedEnqueuePrelinearizedBufferOp(
      OpKernelConstruction* ctx)
      : InfeedEnqueuePrelinearizedBufferOp(
            ctx, absl::make_unique<StreamExecutorTransferOpImpl>()) {}

 private:
  // StreamExecutorInfeedEnqueuePrelinearizedBufferOp is neither copyable nor
  // movable.
  StreamExecutorInfeedEnqueuePrelinearizedBufferOp(
      const StreamExecutorInfeedEnqueuePrelinearizedBufferOp&) = delete;
  StreamExecutorInfeedEnqueuePrelinearizedBufferOp& operator=(
      const StreamExecutorInfeedEnqueuePrelinearizedBufferOp&) = delete;
};
}  // anonymous namespace

TpuInfeedEnqueueOp::TpuInfeedEnqueueOp(
    OpKernelConstruction* ctx,
    std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8,
                               std::move(transfer_op)) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  xla::Shape shape;
  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
  OP_REQUIRES_OK(ctx,
                 GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
}

Status TpuInfeedEnqueueOp::DoWork(OpKernelContext* ctx, int device_ordinal) {
  VLOG(1) << "TpuInfeedEnqueueOp::DoWork. iter_id=" << ctx->frame_iter().iter_id
          << " device_ordinal=" << device_ordinal;
  const Tensor& input_tensor = ctx->input(0);

  // Validate runtime shape and fail if it doesn't match the contract.
  if (input_tensor.dtype() != dtype_) {
    return errors::InvalidArgument("Infeed dtype mismatch.");
  }
  if (input_tensor.shape() != shape_) {
    return errors::InvalidArgument("Infeed shape mismatch; expected ",
                                   shape_.DebugString(), ", got ",
                                   input_tensor.shape().DebugString());
  }

  const Tensor* tensor = &input_tensor;
  Tensor transposed_tensor;
  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(xla_shape_.layout())) {
    // If the given layout is not in dim0major layout, transpose the tensor.
    TF_ASSIGN_OR_RETURN(transposed_tensor,
                        TransposeTensor(ctx, input_tensor, xla_shape_));
    tensor = &transposed_tensor;
  }

  xla::BorrowingLiteral literal;
  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_op_->TransferLiteralToInfeed(device_ordinal, literal));
  VLOG(1) << "TpuInfeedEnqueueOp completes. iter_id="
          << ctx->frame_iter().iter_id << " device_ordinal=" << device_ordinal;
  return OkStatus();
}

TpuInfeedEnqueueTupleOp::TpuInfeedEnqueueTupleOp(
    OpKernelConstruction* ctx,
    std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8,
                               std::move(transfer_op)) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(
      ctx, shapes_.size() == dtypes_.size(),
      errors::InvalidArgument("shapes and dtypes must be the same length."));

  std::vector<xla::Shape> xla_shapes;
  for (int i = 0; i < shapes_.size(); i++) {
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
    xla_shapes.push_back(xla_shape);
  }
  OP_REQUIRES_OK(
      ctx, GetInfeedShapeWithLayout(ctx, "layouts",
                                    xla::ShapeUtil::MakeTupleShape(xla_shapes),
                                    &tuple_shape_));
}

Status TpuInfeedEnqueueTupleOp::DoWork(OpKernelContext* ctx,
                                       int device_ordinal) {
  VLOG(1) << "TpuInfeedEnqueueTupleOp::DoWork. iter_id="
          << ctx->frame_iter().iter_id << " device_ordinal=" << device_ordinal;
  OpInputList values;
  TF_RETURN_IF_ERROR(ctx->input_list("inputs", &values));
  if (values.size() != shapes_.size()) {
    return errors::InvalidArgument(
        "Wrong number of inputs to InfeedEnqueueTuple.");
  }

  for (const auto& shape : shapes_) {
    VLOG(2) << "TransferLiteralToInfeed " << shape.DebugString();
  }

  std::vector<Tensor> maybe_transposed_tensors;
  maybe_transposed_tensors.reserve(values.size());
  for (int i = 0; i < values.size(); i++) {
    // Validate runtime shapes and fail if they don't match the contract.
    const Tensor* tensor = &values[i];
    if (tensor->shape() != shapes_[i]) {
      return errors::InvalidArgument("Infeed shape mismatch for tuple element ",
                                     i, "; expected ", shapes_[i].DebugString(),
                                     ", got ", tensor->shape().DebugString());
    }
    if (!xla::LayoutUtil::IsMonotonicWithDim0Major(
            tuple_shape_.tuple_shapes(i).layout())) {
      // If the given layout is not in dim0major layout, transpose the given
      // tensor.
      TF_ASSIGN_OR_RETURN(
          Tensor transposed_tensor,
          TransposeTensor(ctx, *tensor, tuple_shape_.tuple_shapes(i)));
      maybe_transposed_tensors.emplace_back(transposed_tensor);
    } else {
      maybe_transposed_tensors.emplace_back(*tensor);
    }
  }

  xla::BorrowingLiteral tuple;
  TF_RETURN_IF_ERROR(
      HostTensorsToBorrowingLiteralTuple(maybe_transposed_tensors, &tuple));

  // Transfer the given literal to the Infeed interface of the device.
  TF_RETURN_IF_ERROR(
      transfer_op_->TransferLiteralToInfeed(device_ordinal, tuple));

  VLOG(1) << "TpuInfeedEnqueueTupleOp completes. iter_id="
          << ctx->frame_iter().iter_id << " device_ordinal=" << device_ordinal;

  return OkStatus();
}

InfeedEnqueuePrelinearizedBufferOp::InfeedEnqueuePrelinearizedBufferOp(
    OpKernelConstruction* ctx,
    std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernel(ctx, "prelinearized_buffers_to_infeed", 8,
                               std::move(transfer_op)) {}

Status InfeedEnqueuePrelinearizedBufferOp::DoWork(OpKernelContext* ctx,
                                                  int device_ordinal) {
  const Tensor& input_tensor = ctx->input(0);
  const LinearizedBuffersWrapper* wrapper =
      input_tensor.scalar<tensorflow::Variant>()()
          .get<LinearizedBuffersWrapper>();
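  // `wrapper` is assumed to come from Prelinearize/PrelinearizeTuple;
  // Variant::get() returns nullptr if the variant holds any other type.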
  TF_RETURN_IF_ERROR(
      transfer_op_->TransferBuffersToInfeed(device_ordinal, wrapper->buffers));

  return OkStatus();
}

// These ops execute on either the TPU device or the CPU device. When running
// on CPU, they must specify a non-negative value for device_ordinal to
// indicate which TPU to send infeed to.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueue").Device(DEVICE_TPU_NODE).HostMemory("input"),
    StreamExecutorInfeedEnqueueOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueue").Device(DEVICE_CPU),
                        StreamExecutorInfeedEnqueueOp);

REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueueTuple").Device(DEVICE_TPU_NODE).HostMemory("inputs"),
    StreamExecutorInfeedEnqueueTupleOp);
REGISTER_KERNEL_BUILDER(Name("InfeedEnqueueTuple").Device(DEVICE_CPU),
                        StreamExecutorInfeedEnqueueTupleOp);

// Prelinearize ops run on CPU as part of the tf.data input pipeline.
REGISTER_KERNEL_BUILDER(Name("Prelinearize").Device(DEVICE_CPU),
                        PrelinearizeOp);
REGISTER_KERNEL_BUILDER(Name("PrelinearizeTuple").Device(DEVICE_CPU),
                        PrelinearizeTupleOp);

// The InfeedEnqueuePrelinearizedBuffer op runs on CPU and takes a
// device_ordinal to select the right device to infeed.
REGISTER_KERNEL_BUILDER(
    Name("InfeedEnqueuePrelinearizedBuffer").Device(DEVICE_CPU),
    StreamExecutorInfeedEnqueuePrelinearizedBufferOp);

}  // namespace tensorflow