1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
17
18 #include "tensorflow/compiler/tf2xla/shape_util.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/shape.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/lib/core/errors.h"
29
30 // TensorList is represented by a tuple.
31 // - The first part of the tuple is a buffer containing all the tensors,
32 // - The following parts are push indices for all nested levels of
33 // TensorLists. The last part is push index for the outermost TensorList.
34 //
// TensorList, as its name suggests, is conceptually a list of tensors. In the
// actual representation of a non-nested TensorList, the buffer shape is
// [tensor_list_size, element_shape...]. We will call tensor_list_size the
// "leading dimension" below. Notice that the leading dimension must be a
// compile time constant, since it's part of the buffer shape (it may be marked
// as a dynamic dimension, but its static bound is still fixed).
40 //
41 // Example: consider a 3-level nested TensorList whose element type is scalar.
42 // Assume inner TensorList has leading dimension 4, middle TensorList has 3,
43 // and outer TensorList has 3.
44 // Assume that lower cased letter means there is data in that position, and "."
45 // means there is no data in that position.
46 // First element of outer TensorList:
47 // [ a . . . ]
48 // [ b c . . ]
49 // [ d e f . ]
50 // Second element of outer TensorList:
51 // [ g h i . ]
52 // [ j k . . ]
53 // [ . . . . ]
54 // Third element: not pushed yet.
55 //
56 // The first part of the tuple is an array of shape [3, 3, 4] containing data.
57 // The second part is an array of shape [3, 3], each element is push index
58 // for the inner TensorList. In this case, its values are:
59 // [ 1 2 3 ]
60 // [ 3 2 . ]
61 // [ . . . ]
62 // The third part is an array of shape [3], each element is push index for
63 // the middle TensorList. In this case, its values are:
64 // [ 3 ]
65 // [ 2 ]
66 // [ . ]
// The fourth (and last) part is a scalar. It's the push index for the outer
// TensorList. In this case, its value is 2.
69 //
70 // Now imagine we need to push the following element to the outer TensorList:
71 // [ l . . . ]
72 // [ m n . . ]
73 // [ . . . . ]
74 // This element is represented by a tuple of 3 parts:
75 // First part is all data.
76 // Second part is push indices for the inner TensorList, which is [ 1 2 . ].
77 // Third part is push index for the middle TensorList, which is 2.
78 // Now let's do the push.
79 // First, we append its data to outer TensorList's data.
80 // Then we start to deal with push indices. Similar to data, we append push
81 // indices for each level of TensorList.
82 // For the inner TensorList: append push indices for the pushed element.
83 // [ 1 2 3 ] [ 1 2 3 ]
84 // [ 3 2 . ] + = [ 3 2 . ]
85 // [ . . . ] [ 1 2 . ] [ 1 2 . ]
86 // For the middle TensorList: append push indices for the pushed element.
87 // [ 3 ] [ 3 ]
88 // [ 2 ] + = [ 2 ]
89 // [ . ] [ 2 ] [ 2 ]
90 // For the outer TensorList: just add 1.
91 // 2 + 1 = 3
92 //
// Popping an element from the outer TensorList also follows a similar process.
// The push index for the outer TensorList is first decremented (3 - 1 = 2),
// and that decremented value is used as the slice position for every part.
// First part is data. We get data by slicing data with the decremented push
// index for the outer TensorList (which is 2).
// Second part is push indices for inner TensorList. We get it by slicing
// push indices for inner TensorList with the decremented push index for the
// outer TensorList (which is 2).
// [ 1 2 3 ]
// [ 3 2 . ]
// [ 1 2 . ] ===> This is what we want
// Third part is push index for middle TensorList. We get it by slicing
// push indices for middle TensorList with the decremented push index for the
// outer TensorList (which is 2).
// [ 3 ]
// [ 2 ]
// [ 2 ] ===> This is what we want
108
109 namespace tensorflow {
110
IsTensorListInput(XlaOpKernelContext * ctx,int index)111 bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
112 return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
113 }
114
IsTensorListInitialized(xla::XlaOp list,bool * is_initialized)115 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) {
116 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
117 *is_initialized = list_shape.IsTuple();
118 return OkStatus();
119 }
120
IsNestedTensorList(xla::XlaOp list,bool * is_nested_list)121 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) {
122 bool is_initialized;
123 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
124 if (!is_initialized) {
125 return errors::InvalidArgument("TensorList is not initialized");
126 }
127 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
128 *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2);
129 return OkStatus();
130 }
131
BuildNonNestedTensorList(xla::XlaOp buffer,xla::XlaOp push_index,xla::XlaOp * output_list)132 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index,
133 xla::XlaOp* output_list) {
134 TF_RET_CHECK(buffer.builder());
135 *output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
136 return OkStatus();
137 }
138
GetTensorListBufferShape(xla::XlaOp list,xla::Shape * buffer_shape)139 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) {
140 bool is_initialized;
141 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
142 if (!is_initialized) {
143 return errors::InvalidArgument("TensorList is not initialized");
144 }
145 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
146 *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
147 return OkStatus();
148 }
149
GetTensorListBuffer(xla::XlaOp list,xla::XlaOp * buffer)150 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) {
151 bool is_initialized;
152 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
153 if (!is_initialized) {
154 return errors::InvalidArgument("TensorList is not initialized");
155 }
156 *buffer = xla::GetTupleElement(list, 0);
157 return OkStatus();
158 }
159
GetTensorListPushIndex(xla::XlaOp list,xla::XlaOp * push_index)160 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) {
161 bool is_initialized;
162 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
163 if (!is_initialized) {
164 return errors::InvalidArgument("TensorList is not initialized");
165 }
166 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
167 int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
168 *push_index = xla::GetTupleElement(list, tuple_size - 1);
169 return OkStatus();
170 }
171
SetTensorListPushIndex(xla::XlaOp list,xla::XlaOp push_index,xla::XlaOp * result)172 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
173 xla::XlaOp* result) {
174 bool is_initialized;
175 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
176 if (!is_initialized) {
177 return errors::InvalidArgument("TensorList is not initialized");
178 }
179 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
180 int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
181 std::vector<xla::XlaOp> result_parts;
182 result_parts.reserve(tuple_size);
183 for (int i = 0; i < tuple_size - 1; i++) {
184 result_parts.push_back(xla::GetTupleElement(list, i));
185 }
186 result_parts.push_back(push_index);
187 *result = xla::Tuple(list.builder(), result_parts);
188 return OkStatus();
189 }
190
BuildUninitializedTensorList(xla::XlaBuilder * b,int64_t leading_dimension,bool leading_size_is_dynamic,xla::XlaOp leading_dim_size)191 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
192 int64_t leading_dimension,
193 bool leading_size_is_dynamic,
194 xla::XlaOp leading_dim_size) {
195 auto zero =
196 xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
197 auto broadcast =
198 xla::Broadcast(zero, std::vector<int64_t>{leading_dimension});
199 if (leading_size_is_dynamic) {
200 return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
201 } else {
202 return broadcast;
203 }
204 }
205
GetLeadingDimForTensorList(xla::XlaOp list,int64_t * leading_dim,bool * leading_dim_is_dynamic,xla::XlaOp * leading_dim_dynamic_size)206 Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim,
207 bool* leading_dim_is_dynamic,
208 xla::XlaOp* leading_dim_dynamic_size) {
209 bool is_initialized;
210 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
211 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
212 if (is_initialized) {
213 auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
214 *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
215 auto buffer = xla::GetTupleElement(list, 0);
216 *leading_dim = buffer_shape.dimensions(0);
217 *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
218 } else {
219 *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
220 *leading_dim = list_shape.dimensions(0);
221 *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
222 }
223 return OkStatus();
224 }
225
GetTensorListShapeFromElementTensorListShape(const xla::Shape & element_tensor_list_shape,int64_t leading_dim,bool leading_dim_is_dynamic,xla::Shape * tensor_list_shape)226 Status GetTensorListShapeFromElementTensorListShape(
227 const xla::Shape& element_tensor_list_shape, int64_t leading_dim,
228 bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
229 std::vector<xla::Shape> shapes;
230 int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
231 for (int i = 0; i < tuple_size; i++) {
232 const xla::Shape& shape =
233 xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i);
234 std::vector<int64_t> dimensions = xla::SpanToVector(shape.dimensions());
235 dimensions.insert(dimensions.begin(), leading_dim);
236 shapes.push_back(
237 xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
238 if (leading_dim_is_dynamic) {
239 shapes.back().set_dynamic_dimension(0, true);
240 }
241 }
242 shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
243 std::vector<int64_t>{}));
244 *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
245 return OkStatus();
246 }
247
GetTensorListShapeFromElementShape(const xla::Shape & element_shape,int64_t leading_dim,bool leading_dim_is_dynamic,xla::Shape * tensor_list_shape)248 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
249 int64_t leading_dim,
250 bool leading_dim_is_dynamic,
251 xla::Shape* tensor_list_shape) {
252 if (!element_shape.IsArray()) {
253 return errors::InvalidArgument(
254 "GetTensorListShapeFromElementShape() only supports normal tensor "
255 "shape. But element shape is ",
256 element_shape.DebugString());
257 }
258 std::vector<xla::Shape> shapes;
259 std::vector<int64_t> dimensions =
260 xla::SpanToVector(element_shape.dimensions());
261 dimensions.insert(dimensions.begin(), leading_dim);
262 shapes.push_back(
263 xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
264 shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
265 shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
266 std::vector<int64_t>{}));
267 *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
268 return OkStatus();
269 }
270
CreateZerosTensorListWithShape(xla::XlaBuilder * b,const xla::Shape & list_shape,const std::vector<std::vector<xla::XlaOp>> & dynamic_dims,xla::XlaOp * list)271 Status CreateZerosTensorListWithShape(
272 xla::XlaBuilder* b, const xla::Shape& list_shape,
273 const std::vector<std::vector<xla::XlaOp>>& dynamic_dims,
274 xla::XlaOp* list) {
275 int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
276 std::vector<xla::XlaOp> elements;
277 TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1);
278 for (int i = 0; i < tuple_size - 1; i++) {
279 const xla::Shape& shape =
280 xla::ShapeUtil::GetTupleElementShape(list_shape, i);
281 xla::XlaOp zero =
282 xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
283 xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
284 TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size());
285 for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) {
286 zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim);
287 }
288 elements.push_back(zeros);
289 }
290 // List size (last item) has to be S32.
291 TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1)
292 .element_type() == xla::S32);
293 elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32)));
294 *list = xla::Tuple(b, elements);
295 return OkStatus();
296 }
297
GetInitializedTensorListForElement(xla::XlaOp list,xla::XlaOp element,bool element_is_tensor_list,xla::XlaOp * initialized_list)298 Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
299 bool element_is_tensor_list,
300 xla::XlaOp* initialized_list) {
301 int64_t leading_dim;
302 xla::XlaOp leading_dim_dynamic_size;
303 bool leading_dim_is_dynamic;
304 TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
305 list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));
306
307 xla::XlaBuilder* b = list.builder();
308 xla::Shape list_shape;
309 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
310
311 if (element_is_tensor_list) {
312 TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
313 element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
314 } else {
315 TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
316 element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
317 }
318 bool is_initialized;
319 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
320 if (is_initialized) {
321 // Check shape of initialized list is correct.
322 TF_ASSIGN_OR_RETURN(xla::Shape original_list_shape, b->GetShape(list));
323 if (!xla::ShapeUtil::Compatible(original_list_shape, list_shape)) {
324 return errors::Internal(
325 "Invalid TensorList shape: ", original_list_shape.DebugString(),
326 ", expected: ", list_shape.DebugString());
327 }
328 *initialized_list = list;
329 return OkStatus();
330 } else {
331 // Prepare dynamic dimension dimensions for zero tensor list. The dynamic
332 // sizes are created by reading the dynamic dimension size of sub-elements.
333 std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
334 for (int i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
335 std::vector<xla::XlaOp> dynamic_dims;
336 const xla::Shape& shape = list_shape.tuple_shapes(i);
337 dynamic_dims.push_back(leading_dim_dynamic_size);
338 xla::XlaOp sub_element;
339 if (element_is_tensor_list) {
340 sub_element = xla::GetTupleElement(element, i);
341 } else {
342 sub_element = element;
343 }
344 for (int64_t dim = 0; dim < shape.dimensions_size() - 1; ++dim) {
345 dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
346 }
347 list_dynamic_dims.push_back(dynamic_dims);
348 }
349 return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims,
350 initialized_list);
351 }
352 }
353
ExecuteTensorListPushBack(xla::XlaOp list,xla::XlaOp element,bool element_is_tensor_list,xla::XlaOp * result)354 Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element,
355 bool element_is_tensor_list,
356 xla::XlaOp* result) {
357 bool is_initialized;
358 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
359 if (!is_initialized) {
360 return errors::InvalidArgument("TensorList is not initialized");
361 }
362
363 xla::XlaBuilder* b = list.builder();
364 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
365 int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
366 xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);
367
368 std::vector<xla::XlaOp> result_parts;
369
370 if (element_is_tensor_list) {
371 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
372 int element_tuple_size = xla::ShapeUtil::TupleElementCount(element_shape);
373 for (int i = 0; i < element_tuple_size; i++) {
374 const xla::Shape& element_part_shape =
375 xla::ShapeUtil::GetTupleElementShape(element_shape, i);
376 xla::XlaOp element_part = xla::GetTupleElement(element, i);
377 std::vector<int64_t> element_part_dims =
378 xla::SpanToVector(element_part_shape.dimensions());
379 element_part_dims.insert(element_part_dims.begin(), 1);
380 element_part = xla::Reshape(element_part, element_part_dims);
381
382 std::vector<xla::XlaOp> start_indices(
383 element_part_shape.dimensions_size() + 1,
384 xla::ConstantR0<int32>(b, 0));
385 start_indices[0] = push_index;
386
387 xla::XlaOp list_part = xla::GetTupleElement(list, i);
388 xla::XlaOp updated_list_part =
389 xla::DynamicUpdateSlice(list_part, element_part, start_indices);
390 result_parts.push_back(updated_list_part);
391 }
392 } else {
393 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
394 std::vector<int64_t> element_dims =
395 xla::SpanToVector(element_shape.dimensions());
396 element_dims.insert(element_dims.begin(), 1);
397 xla::XlaOp update = xla::Reshape(element, element_dims);
398
399 std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
400 xla::ConstantR0<int32>(b, 0));
401 start_indices[0] = push_index;
402
403 xla::XlaOp list_part = xla::GetTupleElement(list, 0);
404 xla::XlaOp updated_list_part =
405 xla::DynamicUpdateSlice(list_part, update, start_indices);
406 result_parts.push_back(updated_list_part);
407 }
408
409 xla::XlaOp updated_push_index = push_index + xla::ConstantR0<int32>(b, 1);
410 result_parts.push_back(updated_push_index);
411
412 *result = xla::Tuple(b, result_parts);
413 return OkStatus();
414 }
415
ExecuteTensorListPopBack(xla::XlaOp list,xla::XlaOp * list_result,xla::XlaOp * element_result,bool * element_is_tensor_list)416 Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result,
417 xla::XlaOp* element_result,
418 bool* element_is_tensor_list) {
419 bool is_initialized;
420 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
421 if (!is_initialized) {
422 return errors::InvalidArgument("TensorList is not initialized");
423 }
424
425 // If the TensorList is a nested TensorList, element will be TensorList.
426 TF_RETURN_IF_ERROR(IsNestedTensorList(list, element_is_tensor_list));
427
428 xla::XlaBuilder* b = list.builder();
429 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
430 int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
431 xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);
432 push_index = push_index - xla::ConstantR0<int32>(b, 1);
433
434 std::vector<xla::XlaOp> list_result_parts, element_result_parts;
435 for (int i = 0; i < list_tuple_size - 1; i++) {
436 const xla::Shape& list_part_shape =
437 xla::ShapeUtil::GetTupleElementShape(list_shape, i);
438 std::vector<xla::XlaOp> start_indices(list_part_shape.dimensions_size(),
439 xla::ConstantR0<int32>(b, 0));
440 start_indices[0] = push_index;
441
442 std::vector<int64_t> slice_shape =
443 xla::SpanToVector(list_part_shape.dimensions());
444 slice_shape[0] = 1LL;
445
446 xla::XlaOp list_part = xla::GetTupleElement(list, i);
447 xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
448
449 slice_shape.erase(slice_shape.begin());
450 element_result_parts.push_back(xla::Reshape(read, slice_shape));
451 list_result_parts.push_back(list_part);
452 }
453 list_result_parts.push_back(push_index);
454
455 *list_result = xla::Tuple(b, list_result_parts);
456 if (*element_is_tensor_list) {
457 *element_result = xla::Tuple(b, element_result_parts);
458 } else {
459 *element_result = element_result_parts[0];
460 }
461
462 return OkStatus();
463 }
464
ExecuteTensorListSetItem(xla::XlaOp list,xla::XlaOp index,xla::XlaOp element,xla::XlaOp * result)465 Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index,
466 xla::XlaOp element, xla::XlaOp* result) {
467 bool is_initialized;
468 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
469 if (!is_initialized) {
470 return errors::InvalidArgument("TensorList is not initialized");
471 }
472 bool is_nested;
473 TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
474 if (is_nested) {
475 return errors::Unimplemented(
476 "ExecuteTensorListSetItem() only supports non-nested TensorList");
477 }
478
479 xla::XlaBuilder* b = list.builder();
480 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
481 std::vector<int64_t> element_dims =
482 xla::SpanToVector(element_shape.dimensions());
483 element_dims.insert(element_dims.begin(), 1);
484 xla::XlaOp update = xla::Reshape(element, element_dims);
485
486 std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
487 xla::ConstantR0<int32>(b, 0));
488 start_indices[0] = index;
489
490 xla::XlaOp list_part = xla::GetTupleElement(list, 0);
491 xla::XlaOp updated_list_part =
492 xla::DynamicUpdateSlice(list_part, update, start_indices);
493
494 std::vector<xla::XlaOp> result_parts;
495 result_parts.push_back(updated_list_part);
496 result_parts.push_back(xla::GetTupleElement(list, 1));
497 *result = xla::Tuple(b, result_parts);
498 return OkStatus();
499 }
500
ExecuteTensorListGetItem(xla::XlaOp list,xla::XlaOp index,xla::XlaOp * result)501 Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
502 xla::XlaOp* result) {
503 bool is_initialized;
504 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
505 if (!is_initialized) {
506 return errors::InvalidArgument("TensorList is not initialized");
507 }
508 bool is_nested;
509 TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
510 if (is_nested) {
511 return errors::Unimplemented(
512 "ExecuteTensorListGetItem() only supports non-nested TensorList");
513 }
514
515 xla::XlaBuilder* b = list.builder();
516 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
517 const xla::Shape& buffer_shape =
518 xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
519 std::vector<xla::XlaOp> start_indices(buffer_shape.dimensions_size(),
520 xla::ConstantR0<int32>(b, 0));
521 start_indices[0] = index;
522
523 std::vector<int64_t> slice_shape =
524 xla::SpanToVector(buffer_shape.dimensions());
525 slice_shape[0] = 1LL;
526
527 xla::XlaOp list_part = xla::GetTupleElement(list, 0);
528 xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
529 // Propagate dynamic dimensions from buffer to the sliced buffer, except for
530 // leading dimension (which is always static 1).
531 for (int64_t i = 1; i < buffer_shape.dimensions_size(); ++i) {
532 if (buffer_shape.is_dynamic_dimension(i)) {
533 auto buffer = xla::GetTupleElement(list, 0);
534 auto gds = xla::GetDimensionSize(buffer, i);
535 read = xla::SetDimensionSize(read, gds, i);
536 }
537 }
538 slice_shape.erase(slice_shape.begin());
539 *result = xla::Reshape(read, slice_shape);
540 return OkStatus();
541 }
542
ExecuteTensorListFromTensor(int push_index,xla::XlaOp tensor,xla::XlaOp * result)543 Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor,
544 xla::XlaOp* result) {
545 xla::XlaBuilder* b = tensor.builder();
546 TF_ASSIGN_OR_RETURN(xla::Shape shape, b->GetShape(tensor));
547 if (!shape.IsArray()) {
548 return errors::InvalidArgument(
549 "ExecuteTensorListFromTensor() only supports normal tensor. But input "
550 "shape is ",
551 shape.DebugString());
552 }
553
554 std::vector<xla::XlaOp> result_parts{tensor,
555 xla::ConstantR0<int32>(b, push_index)};
556 *result = xla::Tuple(b, result_parts);
557 return OkStatus();
558 }
559
560 } // namespace tensorflow
561