1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // XLA TensorList operators.
17
18 #include <limits>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
23 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/tf2xla/type_util.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/literal.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/framework/bounds_check.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/partial_tensor_shape.h"
36 #include "tensorflow/core/framework/register_types.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/tensor_types.h"
39 #include "tensorflow/core/framework/types.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/platform/types.h"
43
44 namespace tensorflow {
45
46 namespace {
47
48 // GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
49 // may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If a
50 // dimension is static, a constant dimension is returned. If a dim is dynamic, a
51 // dynamic XlaOp representing the dynamic size is returned.
GetTensorListDynamicDims(XlaOpKernelContext * ctx,const xla::Shape & element_shape,const xla::Shape & list_shape,int64_t num_elements)52 StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
53 XlaOpKernelContext* ctx, const xla::Shape& element_shape,
54 const xla::Shape& list_shape, int64_t num_elements) {
55 std::vector<int64_t> dynamic_sizes;
56 // The multiplier can be a dynamic value.
57 TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
58 std::vector<bool> dims_are_dynamic;
59 TF_RETURN_IF_ERROR(
60 ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
61 bool leading_dim_is_dynamic;
62 TF_RETURN_IF_ERROR(
63 ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
64 std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
65 // Set dynamic dimension size to 0 for initialization value.
66 std::vector<xla::XlaOp> dynamic_dims;
67 dynamic_dims.reserve(1 + element_shape.dimensions_size());
68 if (leading_dim_is_dynamic) {
69 dynamic_dims.push_back(ctx->Input(1));
70 } else {
71 dynamic_dims.push_back(
72 xla::ConstantR0<int32>(ctx->builder(), num_elements));
73 }
74 for (int64_t dim = 0; dim < element_shape.dimensions_size(); ++dim) {
75 if (dims_are_dynamic[dim]) {
76 auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
77 dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
78 dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
79 dynamic_dims.push_back(dynamic_dim_size);
80 } else {
81 dynamic_dims.push_back(
82 xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
83 }
84 }
85 list_dynamic_dims.push_back(std::move(dynamic_dims));
86 return list_dynamic_dims;
87 }
88
89 class TensorListLengthOp : public XlaOpKernel {
90 public:
TensorListLengthOp(OpKernelConstruction * ctx)91 explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
92
Compile(XlaOpKernelContext * ctx)93 void Compile(XlaOpKernelContext* ctx) override {
94 int64_t leading_dim;
95 xla::XlaOp leading_dim_size;
96 bool leading_dim_is_dynamic;
97 OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
98 &leading_dim_is_dynamic,
99 &leading_dim_size));
100 ctx->SetOutput(0, leading_dim_size);
101 }
102
103 private:
104 TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
105 };
106
107 REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);
108
109 // "input" is the shape input for EmptyTensorList/TensorListReserve ops.
110 // If "input" is a compile time constant and not "unknown rank" (-1), return
111 // its value in "*shape".
TryGetElementShapeFromInput(XlaOpKernelContext * ctx,xla::XlaOp input,xla::PrimitiveType dtype,bool * got_shape,xla::Shape * shape)112 Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input,
113 xla::PrimitiveType dtype, bool* got_shape,
114 xla::Shape* shape) {
115 auto is_compile_time_constant_or = input.builder()->IsConstant(input);
116 TF_RETURN_IF_ERROR(is_compile_time_constant_or.status());
117
118 bool is_compile_time_constant = is_compile_time_constant_or.ValueOrDie();
119 if (!is_compile_time_constant) {
120 *got_shape = false;
121 return OkStatus();
122 }
123
124 PartialTensorShape partial_shape;
125 TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape));
126 if (!partial_shape.IsFullyDefined()) {
127 *got_shape = false;
128 return OkStatus();
129 }
130
131 *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes());
132 *got_shape = true;
133 return OkStatus();
134 }
135
136 class TensorListReserveOp : public XlaOpKernel {
137 public:
TensorListReserveOp(OpKernelConstruction * ctx)138 explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
139 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
140 // Only non-nested TensorList is supported for now.
141 OP_REQUIRES(
142 ctx, dtype_ != DT_VARIANT,
143 errors::Unimplemented(
144 "Only non-nested TensorList is supported for TensorListReserve."));
145 }
146
Compile(XlaOpKernelContext * ctx)147 void Compile(XlaOpKernelContext* ctx) override {
148 int64_t num_elements;
149 OP_REQUIRES_OK(ctx,
150 ctx->ConstantInputAsIntScalar(
151 1, &num_elements, xla::ValueInferenceMode::kUpperBound));
152 bool num_element_is_dynamic;
153 OP_REQUIRES_OK(
154 ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
155 OP_REQUIRES(
156 ctx, num_elements >= 0,
157 errors::InvalidArgument(
158 "XLA compilation requires a fixed tensor list size. Set the number "
159 "of elements. This could also happen if you're using a TensorArray "
160 "in a while loop that does not have its maximum_iteration set, you "
161 "can fix this by setting maximum_iteration to a suitable value."));
162
163 // If element shape is compile time constant and it's not "unknown rank"
164 // shape (-1), create an initialized TensorList. Otherwise create an
165 // uninitialized TensorList.
166 xla::XlaOp element_shape_handle = ctx->Input(0);
167 xla::PrimitiveType type;
168 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
169 bool got_shape;
170 xla::Shape element_shape;
171 OP_REQUIRES_OK(ctx,
172 TryGetElementShapeFromInput(ctx, element_shape_handle, type,
173 &got_shape, &element_shape));
174 if (got_shape) {
175 xla::Shape list_shape;
176 OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
177 element_shape, num_elements,
178 num_element_is_dynamic, &list_shape));
179 // Set up dynamic dimension sizes to create the zero tensor.
180 auto list_dynamic_dims_or = GetTensorListDynamicDims(
181 ctx, element_shape, list_shape, num_elements);
182 OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
183 xla::XlaOp new_list;
184 OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
185 ctx->builder(), list_shape,
186 list_dynamic_dims_or.ValueOrDie(), &new_list));
187 xla::XlaOp result;
188 OP_REQUIRES_OK(
189 ctx,
190 SetTensorListPushIndex(
191 new_list, xla::ConstantR0<int32>(ctx->builder(), num_elements),
192 &result));
193 ctx->SetTensorListOutput(0, result);
194 return;
195 }
196
197 xla::XlaOp result = BuildUninitializedTensorList(
198 ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
199 ctx->SetTensorListOutput(0, result);
200 }
201
202 private:
203 DataType dtype_;
204
205 TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
206 };
207
208 REGISTER_XLA_OP(Name("TensorListReserve")
209 .CompileTimeConstantInput("element_shape")
210 .CompileTimeConstantInput("num_elements"),
211 TensorListReserveOp);
212
213 class EmptyTensorListOp : public XlaOpKernel {
214 public:
EmptyTensorListOp(OpKernelConstruction * ctx)215 explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
216 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
217 }
218
Compile(XlaOpKernelContext * ctx)219 void Compile(XlaOpKernelContext* ctx) override {
220 int64_t max_num_elements;
221 OP_REQUIRES_OK(
222 ctx, ctx->ConstantInputAsIntScalar(
223 1, &max_num_elements, xla::ValueInferenceMode::kUpperBound));
224 bool num_element_is_dynamic;
225 OP_REQUIRES_OK(
226 ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
227 OP_REQUIRES(ctx, max_num_elements >= 0,
228 errors::InvalidArgument(
229 "XLA compilation requires a fixed tensor list size. Set "
230 "the max number of elements. This could also happen if "
231 "you're using a TensorArray in a while loop that does not "
232 "have its maximum_iteration set, you can fix this by "
233 "setting maximum_iteration to a suitable value."));
234
235 if (dtype_ != DT_VARIANT) {
236 // We are creating a non-nested TensorList.
237 // If element shape is compile time constant and it's not "unknown
238 // rank" shape (-1), create an initialized TensorList. Otherwise
239 // create an uninitialized TensorList.
240 xla::XlaOp element_shape_handle = ctx->Input(0);
241 xla::PrimitiveType type;
242 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
243 bool got_shape;
244 xla::Shape element_shape;
245 OP_REQUIRES_OK(
246 ctx, TryGetElementShapeFromInput(ctx, element_shape_handle, type,
247 &got_shape, &element_shape));
248 if (got_shape) {
249 xla::Shape list_shape;
250 OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
251 element_shape, max_num_elements,
252 num_element_is_dynamic, &list_shape));
253 // Set up dynamic dimension sizes to create the zero tensor.
254 auto list_dynamic_dims_or = GetTensorListDynamicDims(
255 ctx, element_shape, list_shape, max_num_elements);
256 OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
257
258 xla::XlaOp result;
259 OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
260 ctx->builder(), list_shape,
261 list_dynamic_dims_or.ValueOrDie(), &result));
262
263 ctx->SetTensorListOutput(0, result);
264 return;
265 }
266 }
267
268 // We are creating a nested TensorList or a non-nested TensorList with
269 // unknown shape. Just create an uninitialized TensorList.
270 xla::XlaOp result =
271 BuildUninitializedTensorList(ctx->builder(), max_num_elements,
272 num_element_is_dynamic, ctx->Input(1));
273 ctx->SetTensorListOutput(0, result);
274 }
275
276 private:
277 DataType dtype_;
278
279 TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
280 };
281
282 REGISTER_XLA_OP(Name("EmptyTensorList")
283 .CompileTimeConstantInput("element_shape")
284 .CompileTimeConstantInput("max_num_elements")
285 .AllowVariantTypes(),
286 EmptyTensorListOp);
287
288 class TensorListElementShapeOp : public XlaOpKernel {
289 public:
TensorListElementShapeOp(OpKernelConstruction * ctx)290 explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
291 : XlaOpKernel(ctx) {
292 OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
293 }
294
Compile(XlaOpKernelContext * ctx)295 void Compile(XlaOpKernelContext* ctx) override {
296 // Check that the TensorList is initialized.
297 bool is_initialized;
298 OP_REQUIRES_OK(ctx,
299 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
300 OP_REQUIRES(ctx, is_initialized,
301 errors::InvalidArgument("TensorList is not initialized"));
302
303 // Only non-nested TensorList is supported for now.
304 bool is_nested;
305 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
306 OP_REQUIRES(ctx, !is_nested,
307 errors::Unimplemented("Only non-nested TensorList is supported "
308 "for TensorListElementShape."));
309
310 // For non-nested TensorList, element shape is the buffer shape without
311 // the first dimension.
312 xla::XlaBuilder* b = ctx->builder();
313 xla::Shape list_shape;
314 OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &list_shape));
315 list_shape.DeleteDimension(0);
316
317 switch (shape_type_) {
318 case DT_INT64:
319 ctx->SetOutput(0, xla::ConstantR1<int64_t>(b, list_shape.dimensions()));
320 break;
321 case DT_INT32: {
322 std::vector<int32> size;
323 const auto& dimensions = list_shape.dimensions();
324 size.reserve(dimensions.size());
325 for (int64_t s : dimensions) {
326 size.push_back(s);
327 }
328 ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
329 break;
330 }
331 default:
332 ctx->CtxFailure(
333 errors::InvalidArgument("Unsupported shape type requested"));
334 return;
335 }
336 }
337
338 private:
339 DataType shape_type_;
340
341 TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
342 };
343
344 REGISTER_XLA_OP(Name("TensorListElementShape").IsMetadataOp(),
345 TensorListElementShapeOp);
346
347 class TensorListGetItemOp : public XlaOpKernel {
348 public:
TensorListGetItemOp(OpKernelConstruction * ctx)349 explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
350 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
351 }
352
Compile(XlaOpKernelContext * ctx)353 void Compile(XlaOpKernelContext* ctx) override {
354 // Check that the TensorList is initialized.
355 bool is_initialized;
356 OP_REQUIRES_OK(ctx,
357 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
358 OP_REQUIRES(ctx, is_initialized,
359 errors::InvalidArgument("TensorList is not initialized"));
360
361 // Only non-nested TensorList is supported for now.
362 bool is_nested;
363 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
364 OP_REQUIRES(ctx, !is_nested,
365 errors::Unimplemented("Only non-nested TensorList is supported "
366 "for TensorListGetItem."));
367
368 xla::XlaOp list = ctx->Input(0);
369 xla::XlaOp index = ctx->Input(1);
370
371 xla::XlaOp result;
372 OP_REQUIRES_OK(ctx, ExecuteTensorListGetItem(list, index, &result));
373
374 ctx->SetOutput(0, result);
375 }
376
377 private:
378 DataType dtype_;
379
380 TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp);
381 };
382
383 REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp);
384
385 class TensorListGatherOp : public XlaOpKernel {
386 public:
TensorListGatherOp(OpKernelConstruction * ctx)387 explicit TensorListGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
388 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
389 }
390
Compile(XlaOpKernelContext * ctx)391 void Compile(XlaOpKernelContext* ctx) override {
392 // Check that the TensorList is initialized.
393 bool is_initialized;
394 OP_REQUIRES_OK(ctx,
395 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
396 OP_REQUIRES(ctx, is_initialized,
397 errors::InvalidArgument("TensorList is not initialized"));
398
399 // Only non-nested TensorList is supported for now.
400 bool is_nested;
401 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
402 OP_REQUIRES(ctx, !is_nested,
403 errors::Unimplemented("Only non-nested TensorList is supported "
404 "for TensorListGather."));
405
406 DataType indices_type = ctx->input_type(1);
407
408 const TensorShape indices_shape = ctx->InputShape(1);
409 OP_REQUIRES(ctx, indices_shape.dims() == 1,
410 errors::InvalidArgument("indices must be rank 1"));
411
412 xla::XlaOp list = ctx->Input(0);
413 xla::XlaOp indices = ctx->Input(1);
414
415 xla::XlaOp buffer;
416 OP_REQUIRES_OK(ctx, GetTensorListBuffer(list, &buffer));
417 xla::Shape buffer_xla_shape;
418 OP_REQUIRES_OK(ctx, GetTensorListBufferShape(list, &buffer_xla_shape));
419 TensorShape buffer_shape;
420 OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(buffer_xla_shape, &buffer_shape));
421
422 xla::XlaOp result;
423 OP_REQUIRES_OK(
424 ctx, XlaGather(buffer, buffer_shape, indices, indices_shape, /*axis=*/0,
425 /*indices_are_nd=*/false, dtype_, indices_type,
426 ctx->builder(), &result));
427 ctx->SetOutput(0, result);
428 }
429
430 private:
431 DataType dtype_;
432
433 TF_DISALLOW_COPY_AND_ASSIGN(TensorListGatherOp);
434 };
435
436 REGISTER_XLA_OP(Name("TensorListGather"), TensorListGatherOp);
437
438 class TensorListStackOp : public XlaOpKernel {
439 public:
TensorListStackOp(OpKernelConstruction * ctx)440 explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
441
Compile(XlaOpKernelContext * ctx)442 void Compile(XlaOpKernelContext* ctx) override {
443 // Check that the TensorList is initialized.
444 bool is_initialized;
445 OP_REQUIRES_OK(ctx,
446 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
447 OP_REQUIRES(ctx, is_initialized,
448 errors::InvalidArgument("TensorList is not initialized"));
449
450 // Only non-nested TensorList is supported for now.
451 bool is_nested;
452 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
453 OP_REQUIRES(ctx, !is_nested,
454 errors::Unimplemented("Only non-nested TensorList is supported "
455 "for TensorListGetItem."));
456
457 xla::XlaOp buffer;
458 OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
459 ctx->SetOutput(0, buffer);
460 }
461
462 private:
463 TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp);
464 };
465
466 REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp);
467
468 class TensorListConcatOp : public XlaOpKernel {
469 public:
TensorListConcatOp(OpKernelConstruction * ctx)470 explicit TensorListConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
471
Compile(XlaOpKernelContext * ctx)472 void Compile(XlaOpKernelContext* ctx) override {
473 xla::XlaOp input = ctx->Input(0);
474
475 // Check that the TensorList is initialized.
476 bool is_initialized;
477 OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized)));
478 OP_REQUIRES(ctx, is_initialized,
479 errors::InvalidArgument("TensorList is not initialized"));
480
481 // Only non-nested TensorList is supported for now.
482 bool is_nested;
483 OP_REQUIRES_OK(ctx, IsNestedTensorList(input, &is_nested));
484 OP_REQUIRES(ctx, !is_nested,
485 errors::Unimplemented("Only non-nested TensorList is supported "
486 "for TensorListConcat."));
487
488 xla::XlaOp buffer;
489 OP_REQUIRES_OK(ctx, GetTensorListBuffer(input, &buffer));
490
491 xla::XlaBuilder* b = input.builder();
492 auto shape_or = b->GetShape(buffer);
493 OP_REQUIRES_OK(ctx, shape_or.status());
494 xla::Shape element_shape = std::move(shape_or).value();
495 std::vector<int64_t> element_dims =
496 xla::SpanToVector(element_shape.dimensions());
497 OP_REQUIRES(
498 ctx, element_dims.size() > 1,
499 errors::Unimplemented("TensorList of scalars is not supported"));
500 int64_t num_elements = element_dims[0];
501 int64_t tensor_lengths = element_dims[1];
502
503 std::vector<int64_t> new_dims = {num_elements * tensor_lengths};
504
505 for (int i = 2; i < element_dims.size(); i++) {
506 new_dims.push_back(element_dims[i]);
507 }
508
509 xla::XlaOp out = xla::Reshape(buffer, new_dims);
510 ctx->SetOutput(0, out);
511
512 // Second output is a tensor of lengths of returned tensors.
513 xla::XlaOp lengths = xla::ConstantR1(b, num_elements, tensor_lengths);
514 ctx->SetOutput(1, lengths);
515 }
516
517 private:
518 TF_DISALLOW_COPY_AND_ASSIGN(TensorListConcatOp);
519 };
520
521 REGISTER_XLA_OP(Name("TensorListConcatV2"), TensorListConcatOp);
522
523 class TensorListSplitOp : public XlaOpKernel {
524 public:
TensorListSplitOp(OpKernelConstruction * ctx)525 explicit TensorListSplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
526 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
527 // Only non-nested TensorList is supported for now.
528 OP_REQUIRES(
529 ctx, dtype_ != DT_VARIANT,
530 errors::Unimplemented(
531 "Only non-nested TensorList is supported for TensorListReserve."));
532 }
533
Compile(XlaOpKernelContext * ctx)534 void Compile(XlaOpKernelContext* ctx) override {
535 xla::XlaOp input_tensor = ctx->Input(0);
536
537 xla::XlaBuilder* b = input_tensor.builder();
538 auto shape_or = b->GetShape(input_tensor);
539 OP_REQUIRES_OK(ctx, shape_or.status());
540 xla::Shape element_shape = std::move(shape_or).value();
541 std::vector<int64_t> element_dims =
542 xla::SpanToVector(element_shape.dimensions());
543 OP_REQUIRES(
544 ctx, !element_dims.empty(),
545 errors::Unimplemented("Element dimensions have to be non-empty"));
546
547 std::vector<int64_t> lengths;
548 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
549 OP_REQUIRES(ctx, !lengths.empty(),
550 errors::Unimplemented("Length has to be non-empty"));
551 int64_t length = lengths[0];
552 for (int64_t len : lengths) {
553 OP_REQUIRES(ctx, len == length,
554 errors::Unimplemented("All lengths have to be the same"));
555 }
556 OP_REQUIRES(
557 ctx, element_dims[0] % length == 0,
558 errors::Unimplemented("Buffer size has to be a multiple of length"));
559 std::vector<int64_t> new_dims = {element_dims[0] / length, length};
560 for (int i = 1; i < element_dims.size(); i++) {
561 new_dims.push_back(element_dims[i]);
562 }
563
564 xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims);
565
566 xla::XlaOp result;
567 OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result));
568 ctx->SetTensorListOutput(0, result);
569 }
570
571 private:
572 DataType dtype_;
573
574 TF_DISALLOW_COPY_AND_ASSIGN(TensorListSplitOp);
575 };
576
577 REGISTER_XLA_OP(Name("TensorListSplit")
578 .CompileTimeConstantInput("element_shape")
579 .CompileTimeConstantInput("lengths"),
580 TensorListSplitOp);
581
582 class TensorListFromTensorOp : public XlaOpKernel {
583 public:
TensorListFromTensorOp(OpKernelConstruction * ctx)584 explicit TensorListFromTensorOp(OpKernelConstruction* ctx)
585 : XlaOpKernel(ctx) {}
586
Compile(XlaOpKernelContext * ctx)587 void Compile(XlaOpKernelContext* ctx) override {
588 const TensorShape& tensor_shape = ctx->InputShape(0);
589 int num_elements = tensor_shape.dim_size(0);
590 const xla::XlaOp tensor = ctx->Input(0);
591 xla::XlaOp result;
592 OP_REQUIRES_OK(ctx,
593 ExecuteTensorListFromTensor(num_elements, tensor, &result));
594 auto list_shape_or = ctx->builder()->GetShape(result);
595 ctx->SetTensorListOutput(0, result);
596 }
597
598 private:
599 TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp);
600 };
601
602 REGISTER_XLA_OP(
603 Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"),
604 TensorListFromTensorOp);
605
606 class TensorListSetItemOp : public XlaOpKernel {
607 public:
TensorListSetItemOp(OpKernelConstruction * ctx)608 explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
609
Compile(XlaOpKernelContext * ctx)610 void Compile(XlaOpKernelContext* ctx) override {
611 xla::XlaOp list = ctx->Input(0);
612 xla::XlaOp index = ctx->Input(1);
613 xla::XlaOp element = ctx->Input(2);
614 xla::XlaOp initialized_list;
615 OP_REQUIRES_OK(ctx, GetInitializedTensorListForElement(
616 list, element, /*element_is_tensor_list=*/false,
617 &initialized_list));
618
619 // Only non-nested TensorList is supported for now.
620 bool is_nested;
621 OP_REQUIRES_OK(ctx, IsNestedTensorList(initialized_list, &is_nested));
622 OP_REQUIRES(ctx, !is_nested,
623 errors::Unimplemented("Only non-nested TensorList is supported "
624 "for TensorListSetItem."));
625
626 xla::XlaOp result;
627 OP_REQUIRES_OK(ctx, ExecuteTensorListSetItem(initialized_list, index,
628 element, &result));
629
630 ctx->SetTensorListOutput(0, result);
631 }
632
633 private:
634 TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp);
635 };
636
637 REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp);
638
639 class TensorListPushBackOp : public XlaOpKernel {
640 public:
TensorListPushBackOp(OpKernelConstruction * ctx)641 explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
642
Compile(XlaOpKernelContext * ctx)643 void Compile(XlaOpKernelContext* ctx) override {
644 xla::XlaOp list = ctx->Input(0);
645 xla::XlaOp element = ctx->Input(1);
646 bool element_is_tensor_list = IsTensorListInput(ctx, 1);
647 xla::XlaOp initialized_list;
648 OP_REQUIRES_OK(
649 ctx, GetInitializedTensorListForElement(
650 list, element, element_is_tensor_list, &initialized_list));
651
652 xla::XlaOp result;
653 OP_REQUIRES_OK(ctx,
654 ExecuteTensorListPushBack(initialized_list, element,
655 element_is_tensor_list, &result));
656
657 ctx->SetTensorListOutput(0, result);
658 }
659
660 private:
661 TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
662 };
663
664 REGISTER_XLA_OP(Name("TensorListPushBack").AllowVariantTypes(),
665 TensorListPushBackOp);
666
667 class TensorListPopBackOp : public XlaOpKernel {
668 public:
TensorListPopBackOp(OpKernelConstruction * ctx)669 explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
670
Compile(XlaOpKernelContext * ctx)671 void Compile(XlaOpKernelContext* ctx) override {
672 // Check that the TensorList is initialized.
673 bool is_initialized;
674 OP_REQUIRES_OK(ctx,
675 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
676 OP_REQUIRES(ctx, is_initialized,
677 errors::InvalidArgument("TensorList is not initialized"));
678
679 xla::XlaOp list = ctx->Input(0);
680 xla::XlaOp list_result, element_result;
681 bool element_is_tensor_list;
682 OP_REQUIRES_OK(ctx,
683 ExecuteTensorListPopBack(list, &list_result, &element_result,
684 &element_is_tensor_list));
685
686 ctx->SetTensorListOutput(0, list_result);
687 if (element_is_tensor_list) {
688 ctx->SetTensorListOutput(1, element_result);
689 } else {
690 ctx->SetOutput(1, element_result);
691 }
692 }
693
694 private:
695 DataType dtype_;
696
697 TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
698 };
699
700 REGISTER_XLA_OP(Name("TensorListPopBack").AllowVariantTypes(),
701 TensorListPopBackOp);
702
703 } // anonymous namespace
704 } // namespace tensorflow
705