/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA TensorList operators.

#include <limits>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

// GetTensorListDynamicDims collects the dimension sizes that a TensorList
// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If
// a dimension is static, a constant XlaOp holding its size is returned; if a
// dimension is dynamic, an XlaOp computing the runtime size is returned.
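//
// For example: for a list reserved with num_elements = 5 and element shape
// f32[3, ?], where the second element dimension is dynamic, the result holds
// a single inner vector of three XlaOps:
//   {ConstantR0<int32>(5), ConstantR0<int32>(3), <runtime size of dim 1>}.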
StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
    XlaOpKernelContext* ctx, const xla::Shape& element_shape,
    const xla::Shape& list_shape, int64_t num_elements) {
  std::vector<int64_t> dynamic_sizes;
  // The multiplier can be a dynamic value.
  TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
  std::vector<bool> dims_are_dynamic;
  TF_RETURN_IF_ERROR(
      ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
  bool leading_dim_is_dynamic;
  TF_RETURN_IF_ERROR(
      ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
  std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
  // Collect the size of the leading (list) dimension and of each element
  // dimension, as either a constant or a dynamically computed XlaOp.
  std::vector<xla::XlaOp> dynamic_dims;
  dynamic_dims.reserve(1 + element_shape.dimensions_size());
  if (leading_dim_is_dynamic) {
    dynamic_dims.push_back(ctx->Input(1));
  } else {
    dynamic_dims.push_back(
        xla::ConstantR0<int32>(ctx->builder(), num_elements));
  }
  for (int64_t dim = 0; dim < element_shape.dimensions_size(); ++dim) {
    if (dims_are_dynamic[dim]) {
      // Extract the runtime size of this dimension from the element_shape
      // input and convert it to a scalar S32.
      auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
      dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
      dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
      dynamic_dims.push_back(dynamic_dim_size);
    } else {
      dynamic_dims.push_back(
          xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
    }
  }
  list_dynamic_dims.push_back(std::move(dynamic_dims));
  return list_dynamic_dims;
}

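// Returns the number of elements in a TensorList, i.e. the size of the
// leading dimension of its backing buffer (which may be a dynamic value).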
class TensorListLengthOp : public XlaOpKernel {
 public:
  explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    int64_t leading_dim;
    xla::XlaOp leading_dim_size;
    bool leading_dim_is_dynamic;
    OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
                                                   &leading_dim_is_dynamic,
                                                   &leading_dim_size));
    ctx->SetOutput(0, leading_dim_size);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
};

REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);

109 // "input" is the shape input for EmptyTensorList/TensorListReserve ops.
110 // If "input" is a compile time constant and not "unknown rank" (-1), return
111 // its value in "*shape".
TryGetElementShapeFromInput(XlaOpKernelContext * ctx,xla::XlaOp input,xla::PrimitiveType dtype,bool * got_shape,xla::Shape * shape)112 Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input,
113                                    xla::PrimitiveType dtype, bool* got_shape,
114                                    xla::Shape* shape) {
115   auto is_compile_time_constant_or = input.builder()->IsConstant(input);
116   TF_RETURN_IF_ERROR(is_compile_time_constant_or.status());
117 
118   bool is_compile_time_constant = is_compile_time_constant_or.ValueOrDie();
119   if (!is_compile_time_constant) {
120     *got_shape = false;
121     return OkStatus();
122   }
123 
124   PartialTensorShape partial_shape;
125   TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape));
126   if (!partial_shape.IsFullyDefined()) {
127     *got_shape = false;
128     return OkStatus();
129   }
130 
131   *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes());
132   *got_shape = true;
133   return OkStatus();
134 }
135 
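// TensorListReserve: creates a TensorList with room for num_elements elements
// of the given element shape. If the element shape is statically known, the
// list is materialized as a zero-initialized buffer with its push index set
// to num_elements; otherwise an uninitialized placeholder list is emitted.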
class TensorListReserveOp : public XlaOpKernel {
 public:
  explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
    // Only non-nested TensorList is supported for now.
    OP_REQUIRES(
        ctx, dtype_ != DT_VARIANT,
        errors::Unimplemented(
            "Only non-nested TensorList is supported for TensorListReserve."));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    int64_t num_elements;
    OP_REQUIRES_OK(ctx,
                   ctx->ConstantInputAsIntScalar(
                       1, &num_elements, xla::ValueInferenceMode::kUpperBound));
    bool num_element_is_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
    OP_REQUIRES(
        ctx, num_elements >= 0,
        errors::InvalidArgument(
            "XLA compilation requires a fixed tensor list size. Set the number "
            "of elements. This could also happen if you're using a TensorArray "
            "in a while loop that does not have its maximum_iteration set; you "
            "can fix this by setting maximum_iteration to a suitable value."));

    // If the element shape is a compile-time constant and is not "unknown
    // rank" (-1), create an initialized TensorList. Otherwise create an
    // uninitialized TensorList.
    xla::XlaOp element_shape_handle = ctx->Input(0);
    xla::PrimitiveType type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
    bool got_shape;
    xla::Shape element_shape;
    OP_REQUIRES_OK(ctx,
                   TryGetElementShapeFromInput(ctx, element_shape_handle, type,
                                               &got_shape, &element_shape));
    if (got_shape) {
      xla::Shape list_shape;
      OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
                              element_shape, num_elements,
                              num_element_is_dynamic, &list_shape));
      // Set up dynamic dimension sizes to create the zero tensor.
      auto list_dynamic_dims_or = GetTensorListDynamicDims(
          ctx, element_shape, list_shape, num_elements);
      OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
      xla::XlaOp new_list;
      OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
                              ctx->builder(), list_shape,
                              list_dynamic_dims_or.ValueOrDie(), &new_list));
      xla::XlaOp result;
      OP_REQUIRES_OK(
          ctx,
          SetTensorListPushIndex(
              new_list, xla::ConstantR0<int32>(ctx->builder(), num_elements),
              &result));
      ctx->SetTensorListOutput(0, result);
      return;
    }

    xla::XlaOp result = BuildUninitializedTensorList(
        ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
    ctx->SetTensorListOutput(0, result);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
};

REGISTER_XLA_OP(Name("TensorListReserve")
                    .CompileTimeConstantInput("element_shape")
                    .CompileTimeConstantInput("num_elements"),
                TensorListReserveOp);

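// EmptyTensorList: creates a TensorList with capacity for up to
// max_num_elements elements. As with TensorListReserve, a statically known
// element shape yields a zero-initialized buffer; an unknown shape or a
// nested (DT_VARIANT) element type yields an uninitialized placeholder list.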
class EmptyTensorListOp : public XlaOpKernel {
 public:
  explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    int64_t max_num_elements;
    OP_REQUIRES_OK(
        ctx, ctx->ConstantInputAsIntScalar(
                 1, &max_num_elements, xla::ValueInferenceMode::kUpperBound));
    bool num_element_is_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
    OP_REQUIRES(ctx, max_num_elements >= 0,
                errors::InvalidArgument(
                    "XLA compilation requires a fixed tensor list size. Set "
                    "the max number of elements. This could also happen if "
                    "you're using a TensorArray in a while loop that does not "
                    "have its maximum_iteration set; you can fix this by "
                    "setting maximum_iteration to a suitable value."));

    if (dtype_ != DT_VARIANT) {
      // We are creating a non-nested TensorList.
      // If the element shape is a compile-time constant and is not "unknown
      // rank" (-1), create an initialized TensorList. Otherwise create an
      // uninitialized TensorList.
      xla::XlaOp element_shape_handle = ctx->Input(0);
      xla::PrimitiveType type;
      OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
      bool got_shape;
      xla::Shape element_shape;
      OP_REQUIRES_OK(
          ctx, TryGetElementShapeFromInput(ctx, element_shape_handle, type,
                                           &got_shape, &element_shape));
      if (got_shape) {
        xla::Shape list_shape;
        OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
                                element_shape, max_num_elements,
                                num_element_is_dynamic, &list_shape));
        // Set up dynamic dimension sizes to create the zero tensor.
        auto list_dynamic_dims_or = GetTensorListDynamicDims(
            ctx, element_shape, list_shape, max_num_elements);
        OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());

        xla::XlaOp result;
        OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
                                ctx->builder(), list_shape,
                                list_dynamic_dims_or.ValueOrDie(), &result));

        ctx->SetTensorListOutput(0, result);
        return;
      }
    }

    // We are creating a nested TensorList or a non-nested TensorList with
    // unknown shape. Just create an uninitialized TensorList.
    xla::XlaOp result =
        BuildUninitializedTensorList(ctx->builder(), max_num_elements,
                                     num_element_is_dynamic, ctx->Input(1));
    ctx->SetTensorListOutput(0, result);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
};

REGISTER_XLA_OP(Name("EmptyTensorList")
                    .CompileTimeConstantInput("element_shape")
                    .CompileTimeConstantInput("max_num_elements")
                    .AllowVariantTypes(),
                EmptyTensorListOp);

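// TensorListElementShape: returns the element shape of an initialized,
// non-nested TensorList as a 1-D tensor of type shape_type (int32 or int64).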
class TensorListElementShapeOp : public XlaOpKernel {
 public:
  explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // Check that the TensorList is initialized.
    bool is_initialized;
    OP_REQUIRES_OK(ctx,
                   (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
    OP_REQUIRES(ctx, is_initialized,
                errors::InvalidArgument("TensorList is not initialized"));

    // Only non-nested TensorList is supported for now.
    bool is_nested;
    OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
    OP_REQUIRES(ctx, !is_nested,
                errors::Unimplemented("Only non-nested TensorList is supported "
                                      "for TensorListElementShape."));

    // For non-nested TensorList, element shape is the buffer shape without
    // the first dimension.
    xla::XlaBuilder* b = ctx->builder();
    xla::Shape list_shape;
    OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &list_shape));
    list_shape.DeleteDimension(0);

    switch (shape_type_) {
      case DT_INT64:
        ctx->SetOutput(0, xla::ConstantR1<int64_t>(b, list_shape.dimensions()));
        break;
      case DT_INT32: {
        std::vector<int32> size;
        const auto& dimensions = list_shape.dimensions();
        size.reserve(dimensions.size());
        for (int64_t s : dimensions) {
          size.push_back(s);
        }
        ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
        break;
      }
      default:
        ctx->CtxFailure(
            errors::InvalidArgument("Unsupported shape type requested"));
        return;
    }
  }

 private:
  DataType shape_type_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
};

REGISTER_XLA_OP(Name("TensorListElementShape").IsMetadataOp(),
                TensorListElementShapeOp);

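// TensorListGetItem: reads the element at the given index out of the list's
// backing buffer.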
class TensorListGetItemOp : public XlaOpKernel {
 public:
  explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // Check that the TensorList is initialized.
    bool is_initialized;
    OP_REQUIRES_OK(ctx,
                   (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
    OP_REQUIRES(ctx, is_initialized,
                errors::InvalidArgument("TensorList is not initialized"));

    // Only non-nested TensorList is supported for now.
    bool is_nested;
    OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
    OP_REQUIRES(ctx, !is_nested,
                errors::Unimplemented("Only non-nested TensorList is supported "
                                      "for TensorListGetItem."));

    xla::XlaOp list = ctx->Input(0);
    xla::XlaOp index = ctx->Input(1);

    xla::XlaOp result;
    OP_REQUIRES_OK(ctx, ExecuteTensorListGetItem(list, index, &result));

    ctx->SetOutput(0, result);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp);
};

REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp);

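// TensorListGather: gathers the elements at the given rank-1 indices from the
// list's backing buffer into a single stacked tensor.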
class TensorListGatherOp : public XlaOpKernel {
 public:
  explicit TensorListGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // Check that the TensorList is initialized.
    bool is_initialized;
    OP_REQUIRES_OK(ctx,
                   (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
    OP_REQUIRES(ctx, is_initialized,
                errors::InvalidArgument("TensorList is not initialized"));

    // Only non-nested TensorList is supported for now.
    bool is_nested;
    OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
    OP_REQUIRES(ctx, !is_nested,
                errors::Unimplemented("Only non-nested TensorList is supported "
                                      "for TensorListGather."));

    DataType indices_type = ctx->input_type(1);

    const TensorShape indices_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, indices_shape.dims() == 1,
                errors::InvalidArgument("indices must be rank 1"));

    xla::XlaOp list = ctx->Input(0);
    xla::XlaOp indices = ctx->Input(1);

    xla::XlaOp buffer;
    OP_REQUIRES_OK(ctx, GetTensorListBuffer(list, &buffer));
    xla::Shape buffer_xla_shape;
    OP_REQUIRES_OK(ctx, GetTensorListBufferShape(list, &buffer_xla_shape));
    TensorShape buffer_shape;
    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(buffer_xla_shape, &buffer_shape));

    xla::XlaOp result;
    OP_REQUIRES_OK(
        ctx, XlaGather(buffer, buffer_shape, indices, indices_shape, /*axis=*/0,
                       /*indices_are_nd=*/false, dtype_, indices_type,
                       ctx->builder(), &result));
    ctx->SetOutput(0, result);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorListGatherOp);
};

REGISTER_XLA_OP(Name("TensorListGather"), TensorListGatherOp);

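// TensorListStack: stacks all list elements into one tensor whose leading
// dimension is the list length. Since this lowering already stores the list
// as a dense buffer, the buffer itself is the result.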
class TensorListStackOp : public XlaOpKernel {
 public:
  explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Check that the TensorList is initialized.
    bool is_initialized;
    OP_REQUIRES_OK(ctx,
                   (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
    OP_REQUIRES(ctx, is_initialized,
                errors::InvalidArgument("TensorList is not initialized"));

    // Only non-nested TensorList is supported for now.
    bool is_nested;
    OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
    OP_REQUIRES(ctx, !is_nested,
                errors::Unimplemented("Only non-nested TensorList is supported "
                                      "for TensorListStack."));

    xla::XlaOp buffer;
    OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
    ctx->SetOutput(0, buffer);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp);
};

REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp);

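// TensorListConcatV2: concatenates all list elements along their leading
// dimension by collapsing the first two buffer dimensions into one. The
// second output holds the (uniform) leading-dimension length of each element.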
class TensorListConcatOp : public XlaOpKernel {
 public:
  explicit TensorListConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);

    // Check that the TensorList is initialized.
    bool is_initialized;
    OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized)));
    OP_REQUIRES(ctx, is_initialized,
                errors::InvalidArgument("TensorList is not initialized"));

    // Only non-nested TensorList is supported for now.
    bool is_nested;
    OP_REQUIRES_OK(ctx, IsNestedTensorList(input, &is_nested));
    OP_REQUIRES(ctx, !is_nested,
                errors::Unimplemented("Only non-nested TensorList is supported "
                                      "for TensorListConcat."));

    xla::XlaOp buffer;
    OP_REQUIRES_OK(ctx, GetTensorListBuffer(input, &buffer));

    xla::XlaBuilder* b = input.builder();
    auto shape_or = b->GetShape(buffer);
    OP_REQUIRES_OK(ctx, shape_or.status());
    xla::Shape element_shape = std::move(shape_or).value();
    std::vector<int64_t> element_dims =
        xla::SpanToVector(element_shape.dimensions());
    OP_REQUIRES(
        ctx, element_dims.size() > 1,
        errors::Unimplemented("TensorList of scalars is not supported"));
    int64_t num_elements = element_dims[0];
    int64_t tensor_lengths = element_dims[1];

    // Collapse the leading two buffer dimensions [num_elements,
    // tensor_lengths] into a single concatenated dimension.
    std::vector<int64_t> new_dims = {num_elements * tensor_lengths};

    for (int i = 2; i < element_dims.size(); i++) {
      new_dims.push_back(element_dims[i]);
    }

    xla::XlaOp out = xla::Reshape(buffer, new_dims);
    ctx->SetOutput(0, out);

    // Second output is a tensor of lengths of returned tensors.
    xla::XlaOp lengths = xla::ConstantR1(b, num_elements, tensor_lengths);
    ctx->SetOutput(1, lengths);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorListConcatOp);
};

REGISTER_XLA_OP(Name("TensorListConcatV2"), TensorListConcatOp);

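// TensorListSplit: splits the input tensor along its leading dimension into
// equal-sized pieces and returns them as a TensorList. Only uniform lengths
// are supported by this lowering.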
class TensorListSplitOp : public XlaOpKernel {
 public:
  explicit TensorListSplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
    // Only non-nested TensorList is supported for now.
    OP_REQUIRES(
        ctx, dtype_ != DT_VARIANT,
        errors::Unimplemented(
            "Only non-nested TensorList is supported for TensorListSplit."));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input_tensor = ctx->Input(0);

    xla::XlaBuilder* b = input_tensor.builder();
    auto shape_or = b->GetShape(input_tensor);
    OP_REQUIRES_OK(ctx, shape_or.status());
    xla::Shape element_shape = std::move(shape_or).value();
    std::vector<int64_t> element_dims =
        xla::SpanToVector(element_shape.dimensions());
    OP_REQUIRES(
        ctx, !element_dims.empty(),
        errors::Unimplemented("Element dimensions have to be non-empty"));

    std::vector<int64_t> lengths;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
    OP_REQUIRES(ctx, !lengths.empty(),
                errors::Unimplemented("Length has to be non-empty"));
    int64_t length = lengths[0];
    for (int64_t len : lengths) {
      OP_REQUIRES(ctx, len == length,
                  errors::Unimplemented("All lengths have to be the same"));
    }
    OP_REQUIRES(
        ctx, element_dims[0] % length == 0,
        errors::Unimplemented("Buffer size has to be a multiple of length"));
    // Split the leading dimension into [num_pieces, length] so each piece
    // becomes one list element.
    std::vector<int64_t> new_dims = {element_dims[0] / length, length};
    for (int i = 1; i < element_dims.size(); i++) {
      new_dims.push_back(element_dims[i]);
    }

    xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims);

    xla::XlaOp result;
    OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result));
    ctx->SetTensorListOutput(0, result);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorListSplitOp);
};

REGISTER_XLA_OP(Name("TensorListSplit")
                    .CompileTimeConstantInput("element_shape")
                    .CompileTimeConstantInput("lengths"),
                TensorListSplitOp);

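// TensorListFromTensor: converts a tensor into a TensorList whose elements
// are the slices of the tensor along its leading dimension.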
class TensorListFromTensorOp : public XlaOpKernel {
 public:
  explicit TensorListFromTensorOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape& tensor_shape = ctx->InputShape(0);
    int num_elements = tensor_shape.dim_size(0);
    const xla::XlaOp tensor = ctx->Input(0);
    xla::XlaOp result;
    OP_REQUIRES_OK(ctx,
                   ExecuteTensorListFromTensor(num_elements, tensor, &result));
    ctx->SetTensorListOutput(0, result);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp);
};

REGISTER_XLA_OP(
    Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"),
    TensorListFromTensorOp);

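// TensorListSetItem: writes the element at the given index, initializing the
// list's buffer from the element's shape first if necessary.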
class TensorListSetItemOp : public XlaOpKernel {
 public:
  explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp list = ctx->Input(0);
    xla::XlaOp index = ctx->Input(1);
    xla::XlaOp element = ctx->Input(2);
    xla::XlaOp initialized_list;
    OP_REQUIRES_OK(ctx, GetInitializedTensorListForElement(
                            list, element, /*element_is_tensor_list=*/false,
                            &initialized_list));

    // Only non-nested TensorList is supported for now.
    bool is_nested;
    OP_REQUIRES_OK(ctx, IsNestedTensorList(initialized_list, &is_nested));
    OP_REQUIRES(ctx, !is_nested,
                errors::Unimplemented("Only non-nested TensorList is supported "
                                      "for TensorListSetItem."));

    xla::XlaOp result;
    OP_REQUIRES_OK(ctx, ExecuteTensorListSetItem(initialized_list, index,
                                                 element, &result));

    ctx->SetTensorListOutput(0, result);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp);
};

REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp);

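// TensorListPushBack: appends an element to the list and advances its push
// index. The element may itself be a TensorList (nested case).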
class TensorListPushBackOp : public XlaOpKernel {
 public:
  explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp list = ctx->Input(0);
    xla::XlaOp element = ctx->Input(1);
    bool element_is_tensor_list = IsTensorListInput(ctx, 1);
    xla::XlaOp initialized_list;
    OP_REQUIRES_OK(
        ctx, GetInitializedTensorListForElement(
                 list, element, element_is_tensor_list, &initialized_list));

    xla::XlaOp result;
    OP_REQUIRES_OK(ctx,
                   ExecuteTensorListPushBack(initialized_list, element,
                                             element_is_tensor_list, &result));

    ctx->SetTensorListOutput(0, result);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
};

REGISTER_XLA_OP(Name("TensorListPushBack").AllowVariantTypes(),
                TensorListPushBackOp);

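// TensorListPopBack: removes the most recently pushed element and returns
// both the updated list and the popped element (which may itself be a
// TensorList).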
class TensorListPopBackOp : public XlaOpKernel {
 public:
  explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Check that the TensorList is initialized.
    bool is_initialized;
    OP_REQUIRES_OK(ctx,
                   (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
    OP_REQUIRES(ctx, is_initialized,
                errors::InvalidArgument("TensorList is not initialized"));

    xla::XlaOp list = ctx->Input(0);
    xla::XlaOp list_result, element_result;
    bool element_is_tensor_list;
    OP_REQUIRES_OK(ctx,
                   ExecuteTensorListPopBack(list, &list_result, &element_result,
                                            &element_is_tensor_list));

    ctx->SetTensorListOutput(0, list_result);
    if (element_is_tensor_list) {
      ctx->SetTensorListOutput(1, element_result);
    } else {
      ctx->SetOutput(1, element_result);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
};

REGISTER_XLA_OP(Name("TensorListPopBack").AllowVariantTypes(),
                TensorListPopBackOp);

}  // anonymous namespace
}  // namespace tensorflow