/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
17 
18 #include "tensorflow/compiler/tf2xla/shape_util.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/shape.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 
// TensorList is represented by a tuple.
// - The first part of the tuple is a buffer containing all the tensors.
// - The following parts are push indices for all nested levels of
//   TensorLists. The last part is the push index for the outermost TensorList.
//
// A TensorList, as its name suggests, is conceptually a list of tensors. In the
// actual representation of a non-nested TensorList, the buffer shape is
// [tensor_list_size, element_shape]. We will call tensor_list_size the "leading
// dimension" below. Notice that the leading dimension must be a compile-time
// constant, since it's part of the buffer shape.
//
// Example: consider a 3-level nested TensorList whose element type is scalar.
// Assume the inner TensorList has leading dimension 4, the middle TensorList
// has 3, and the outer TensorList has 3.
// Assume that a lower-case letter means there is data in that position, and "."
// means there is no data in that position.
// First element of outer TensorList:
// [ a . . . ]
// [ b c . . ]
// [ d e f . ]
// Second element of outer TensorList:
// [ g h i . ]
// [ j k . . ]
// [ . . . . ]
// Third element: not pushed yet.
//
// The first part of the tuple is an array of shape [3, 3, 4] containing data.
// The second part is an array of shape [3, 3]; each element is the push index
// of the corresponding inner TensorList. In this case, its values are:
// [ 1 2 3 ]
// [ 3 2 . ]
// [ . . . ]
// The third part is an array of shape [3]; each element is the push index of
// the corresponding middle TensorList. In this case, its values are:
// [ 3 ]
// [ 2 ]
// [ . ]
// The fourth (and last) part is a scalar. It's the push index for the outer
// TensorList. In this case, its value is 2.
//
// Now imagine we need to push the following element onto the outer TensorList:
// [ l . . . ]
// [ m n . . ]
// [ . . . . ]
// This element is represented by a tuple of 3 parts:
// The first part is its data.
// The second part is the push indices of its inner TensorLists, i.e. [ 1 2 . ].
// The third part is the push index of the middle TensorList, i.e. 2.
// Now let's do the push.
// First, we append its data to the outer TensorList's data.
// Then we deal with the push indices. Similar to the data, we append the push
// indices for each level of TensorList.
// For the inner TensorLists: append the push indices of the pushed element.
// [ 1 2 3 ]               [ 1 2 3 ]
// [ 3 2 . ] +           = [ 3 2 . ]
// [ . . . ]   [ 1 2 . ]   [ 1 2 . ]
// For the middle TensorList: append the push index of the pushed element.
// [ 3 ]           [ 3 ]
// [ 2 ] +       = [ 2 ]
// [ . ]   [ 2 ]   [ 2 ]
// For the outer TensorList: just add 1.
// 2 + 1 = 3
//
//
// Popping an element from the outer TensorList follows a similar process, in
// reverse. The outer push index is first decremented from 3 to 2, and the
// element is read back by slicing each part at index 2.
// The first part is the data, obtained by slicing the data buffer at index 2.
// The second part is the push indices of the inner TensorLists, obtained by
// slicing the inner push indices at index 2:
// [ 1 2 3 ]
// [ 3 2 . ]
// [ 1 2 . ] ===> This is what we want
// The third part is the push index of the middle TensorList, obtained by
// slicing the middle push indices at index 2:
// [ 3 ]
// [ 2 ]
// [ 2 ] ===> This is what we want
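//
// As a concrete sketch of the example above (assuming f32 elements, since the
// element type is not spelled out), the outer TensorList would be an XLA tuple
// of shape:
//   (f32[3, 3, 4],  // data buffer
//    s32[3, 3],     // push indices of the inner TensorLists
//    s32[3],        // push indices of the middle TensorLists
//    s32[])         // push index of the outer TensorList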

namespace tensorflow {

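// Returns whether the `index`-th input of `ctx` is a TensorList expression.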
bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
  return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
}

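// Sets `*is_initialized` to true if `list` is an initialized TensorList, i.e.
// its XLA shape is a tuple.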
Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) {
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  *is_initialized = list_shape.IsTuple();
  return OkStatus();
}

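// Sets `*is_nested_list` to true if the initialized TensorList `list` is
// nested, i.e. its tuple has more than two parts (buffer + push index).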
Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2);
  return OkStatus();
}

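// Builds a non-nested TensorList as a (buffer, push_index) tuple.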
Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index,
                                xla::XlaOp* output_list) {
  TF_RET_CHECK(buffer.builder());
  *output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
  return OkStatus();
}

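// Returns the shape of the buffer (the first tuple part) of an initialized
// TensorList.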
Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
  return OkStatus();
}

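// Returns the buffer (the first tuple part) of an initialized TensorList.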
Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  *buffer = xla::GetTupleElement(list, 0);
  return OkStatus();
}

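// Returns the outermost push index (the last tuple part) of an initialized
// TensorList.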
Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  *push_index = xla::GetTupleElement(list, tuple_size - 1);
  return OkStatus();
}

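// Returns a copy of `list` whose outermost push index (the last tuple part) is
// replaced with `push_index`.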
Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
                              xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  std::vector<xla::XlaOp> result_parts;
  result_parts.reserve(tuple_size);
  for (int i = 0; i < tuple_size - 1; i++) {
    result_parts.push_back(xla::GetTupleElement(list, i));
  }
  result_parts.push_back(push_index);
  *result = xla::Tuple(list.builder(), result_parts);
  return OkStatus();
}

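// Builds an uninitialized TensorList: an S32 zero vector of length
// `leading_dimension`, optionally marked with a dynamic leading dimension.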
xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
                                        int64_t leading_dimension,
                                        bool leading_size_is_dynamic,
                                        xla::XlaOp leading_dim_size) {
  auto zero =
      xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
  auto broadcast =
      xla::Broadcast(zero, std::vector<int64_t>{leading_dimension});
  if (leading_size_is_dynamic) {
    return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
  } else {
    return broadcast;
  }
}

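// Returns the leading dimension of `list`, both as a compile-time constant and
// as a dynamic size op, along with whether the leading dimension is dynamic.
// Works for both initialized and uninitialized TensorLists.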
Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim,
                                  bool* leading_dim_is_dynamic,
                                  xla::XlaOp* leading_dim_dynamic_size) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  if (is_initialized) {
    auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
    *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
    auto buffer = xla::GetTupleElement(list, 0);
    *leading_dim = buffer_shape.dimensions(0);
    *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
  } else {
    *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
    *leading_dim = list_shape.dimensions(0);
    *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
  }
  return OkStatus();
}

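// Computes the shape of a nested TensorList from the tuple shape of its
// element TensorList: each part of the element shape gets `leading_dim`
// prepended, and a scalar S32 push index is appended for the outer level.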
Status GetTensorListShapeFromElementTensorListShape(
    const xla::Shape& element_tensor_list_shape, int64_t leading_dim,
    bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
  std::vector<xla::Shape> shapes;
  int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
  for (int i = 0; i < tuple_size; i++) {
    const xla::Shape& shape =
        xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i);
    std::vector<int64_t> dimensions = xla::SpanToVector(shape.dimensions());
    dimensions.insert(dimensions.begin(), leading_dim);
    shapes.push_back(
        xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
    if (leading_dim_is_dynamic) {
      shapes.back().set_dynamic_dimension(0, true);
    }
  }
  shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
                                             std::vector<int64_t>{}));
  *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return OkStatus();
}

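// Computes the shape of a non-nested TensorList from its element shape: the
// buffer is the element shape with `leading_dim` prepended, plus a scalar S32
// push index.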
Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                          int64_t leading_dim,
                                          bool leading_dim_is_dynamic,
                                          xla::Shape* tensor_list_shape) {
  if (!element_shape.IsArray()) {
    return errors::InvalidArgument(
        "GetTensorListShapeFromElementShape() only supports normal tensor "
        "shape. But element shape is ",
        element_shape.DebugString());
  }
  std::vector<xla::Shape> shapes;
  std::vector<int64_t> dimensions =
      xla::SpanToVector(element_shape.dimensions());
  dimensions.insert(dimensions.begin(), leading_dim);
  shapes.push_back(
      xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
  shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
  shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
                                             std::vector<int64_t>{}));
  *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return OkStatus();
}

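// Creates a TensorList of the given tuple shape with all buffers zeroed and
// the push index set to 0. `dynamic_dims[i][dim]` supplies the dynamic size of
// dimension `dim` of tuple part `i`.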
Status CreateZerosTensorListWithShape(
    xla::XlaBuilder* b, const xla::Shape& list_shape,
    const std::vector<std::vector<xla::XlaOp>>& dynamic_dims,
    xla::XlaOp* list) {
  int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  std::vector<xla::XlaOp> elements;
  TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1);
  for (int i = 0; i < tuple_size - 1; i++) {
    const xla::Shape& shape =
        xla::ShapeUtil::GetTupleElementShape(list_shape, i);
    xla::XlaOp zero =
        xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
    xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
    TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size());
    for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) {
      zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim);
    }
    elements.push_back(zeros);
  }
  // List size (last item) has to be S32.
  TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1)
                   .element_type() == xla::S32);
  elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32)));
  *list = xla::Tuple(b, elements);
  return OkStatus();
}

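// Returns `list` if it is already initialized (verifying its shape matches the
// shape implied by `element`); otherwise creates a zero-initialized TensorList
// whose shape is derived from `element` and the leading dimension of `list`.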
Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
                                          bool element_is_tensor_list,
                                          xla::XlaOp* initialized_list) {
  int64_t leading_dim;
  xla::XlaOp leading_dim_dynamic_size;
  bool leading_dim_is_dynamic;
  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
      list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));

  xla::XlaBuilder* b = list.builder();
  xla::Shape list_shape;
  TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));

  if (element_is_tensor_list) {
    TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
  } else {
    TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
  }
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (is_initialized) {
    // Check that the shape of the initialized list is correct.
    TF_ASSIGN_OR_RETURN(xla::Shape original_list_shape, b->GetShape(list));
    if (!xla::ShapeUtil::Compatible(original_list_shape, list_shape)) {
      return errors::Internal(
          "Invalid TensorList shape: ", original_list_shape.DebugString(),
          ", expected: ", list_shape.DebugString());
    }
    *initialized_list = list;
    return OkStatus();
  } else {
    // Prepare dynamic dimensions for the zeros tensor list. The dynamic
    // sizes are created by reading the dynamic dimension size of sub-elements.
    std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
    for (int i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
      std::vector<xla::XlaOp> dynamic_dims;
      const xla::Shape& shape = list_shape.tuple_shapes(i);
      dynamic_dims.push_back(leading_dim_dynamic_size);
      xla::XlaOp sub_element;
      if (element_is_tensor_list) {
        sub_element = xla::GetTupleElement(element, i);
      } else {
        sub_element = element;
      }
      for (int64_t dim = 0; dim < shape.dimensions_size() - 1; ++dim) {
        dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
      }
      list_dynamic_dims.push_back(dynamic_dims);
    }
    return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims,
                                          initialized_list);
  }
}

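// Appends `element` to an initialized TensorList: writes each part of the
// element into the corresponding buffer at the current push index (via
// DynamicUpdateSlice) and increments the push index.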
Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element,
                                 bool element_is_tensor_list,
                                 xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
  int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);

  std::vector<xla::XlaOp> result_parts;

  if (element_is_tensor_list) {
    TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
    int element_tuple_size = xla::ShapeUtil::TupleElementCount(element_shape);
    for (int i = 0; i < element_tuple_size; i++) {
      const xla::Shape& element_part_shape =
          xla::ShapeUtil::GetTupleElementShape(element_shape, i);
      xla::XlaOp element_part = xla::GetTupleElement(element, i);
      std::vector<int64_t> element_part_dims =
          xla::SpanToVector(element_part_shape.dimensions());
      element_part_dims.insert(element_part_dims.begin(), 1);
      element_part = xla::Reshape(element_part, element_part_dims);

      std::vector<xla::XlaOp> start_indices(
          element_part_shape.dimensions_size() + 1,
          xla::ConstantR0<int32>(b, 0));
      start_indices[0] = push_index;

      xla::XlaOp list_part = xla::GetTupleElement(list, i);
      xla::XlaOp updated_list_part =
          xla::DynamicUpdateSlice(list_part, element_part, start_indices);
      result_parts.push_back(updated_list_part);
    }
  } else {
    TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
    std::vector<int64_t> element_dims =
        xla::SpanToVector(element_shape.dimensions());
    element_dims.insert(element_dims.begin(), 1);
    xla::XlaOp update = xla::Reshape(element, element_dims);

    std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = push_index;

    xla::XlaOp list_part = xla::GetTupleElement(list, 0);
    xla::XlaOp updated_list_part =
        xla::DynamicUpdateSlice(list_part, update, start_indices);
    result_parts.push_back(updated_list_part);
  }

  xla::XlaOp updated_push_index = push_index + xla::ConstantR0<int32>(b, 1);
  result_parts.push_back(updated_push_index);

  *result = xla::Tuple(b, result_parts);
  return OkStatus();
}

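// Pops the most recently pushed element off an initialized TensorList:
// decrements the push index and slices each buffer at the decremented index.
// The buffers themselves are left unchanged in the returned list.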
Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result,
                                xla::XlaOp* element_result,
                                bool* element_is_tensor_list) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }

  // If the TensorList is a nested TensorList, the element will be a
  // TensorList.
  TF_RETURN_IF_ERROR(IsNestedTensorList(list, element_is_tensor_list));

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
  int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);
  push_index = push_index - xla::ConstantR0<int32>(b, 1);

  std::vector<xla::XlaOp> list_result_parts, element_result_parts;
  for (int i = 0; i < list_tuple_size - 1; i++) {
    const xla::Shape& list_part_shape =
        xla::ShapeUtil::GetTupleElementShape(list_shape, i);
    std::vector<xla::XlaOp> start_indices(list_part_shape.dimensions_size(),
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = push_index;

    std::vector<int64_t> slice_shape =
        xla::SpanToVector(list_part_shape.dimensions());
    slice_shape[0] = 1LL;

    xla::XlaOp list_part = xla::GetTupleElement(list, i);
    xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);

    slice_shape.erase(slice_shape.begin());
    element_result_parts.push_back(xla::Reshape(read, slice_shape));
    list_result_parts.push_back(list_part);
  }
  list_result_parts.push_back(push_index);

  *list_result = xla::Tuple(b, list_result_parts);
  if (*element_is_tensor_list) {
    *element_result = xla::Tuple(b, element_result_parts);
  } else {
    *element_result = element_result_parts[0];
  }

  return OkStatus();
}

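// Writes `element` into a non-nested, initialized TensorList at position
// `index`, leaving the push index unchanged.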
Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index,
                                xla::XlaOp element, xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  bool is_nested;
  TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
  if (is_nested) {
    return errors::Unimplemented(
        "ExecuteTensorListSetItem() only supports non-nested TensorList");
  }

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
  std::vector<int64_t> element_dims =
      xla::SpanToVector(element_shape.dimensions());
  element_dims.insert(element_dims.begin(), 1);
  xla::XlaOp update = xla::Reshape(element, element_dims);

  std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
                                        xla::ConstantR0<int32>(b, 0));
  start_indices[0] = index;

  xla::XlaOp list_part = xla::GetTupleElement(list, 0);
  xla::XlaOp updated_list_part =
      xla::DynamicUpdateSlice(list_part, update, start_indices);

  std::vector<xla::XlaOp> result_parts;
  result_parts.push_back(updated_list_part);
  result_parts.push_back(xla::GetTupleElement(list, 1));
  *result = xla::Tuple(b, result_parts);
  return OkStatus();
}

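// Reads the element at position `index` from a non-nested, initialized
// TensorList, preserving any dynamic non-leading dimensions of the buffer.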
Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
                                xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  bool is_nested;
  TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
  if (is_nested) {
    return errors::Unimplemented(
        "ExecuteTensorListGetItem() only supports non-nested TensorList");
  }

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
  const xla::Shape& buffer_shape =
      xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
  std::vector<xla::XlaOp> start_indices(buffer_shape.dimensions_size(),
                                        xla::ConstantR0<int32>(b, 0));
  start_indices[0] = index;

  std::vector<int64_t> slice_shape =
      xla::SpanToVector(buffer_shape.dimensions());
  slice_shape[0] = 1LL;

  xla::XlaOp list_part = xla::GetTupleElement(list, 0);
  xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
  // Propagate dynamic dimensions from the buffer to the sliced buffer, except
  // for the leading dimension (which is always static 1).
  for (int64_t i = 1; i < buffer_shape.dimensions_size(); ++i) {
    if (buffer_shape.is_dynamic_dimension(i)) {
      auto buffer = xla::GetTupleElement(list, 0);
      auto gds = xla::GetDimensionSize(buffer, i);
      read = xla::SetDimensionSize(read, gds, i);
    }
  }
  slice_shape.erase(slice_shape.begin());
  *result = xla::Reshape(read, slice_shape);
  return OkStatus();
}

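// Builds a non-nested TensorList whose buffer is `tensor` and whose push index
// is the constant `push_index`.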
Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor,
                                   xla::XlaOp* result) {
  xla::XlaBuilder* b = tensor.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape shape, b->GetShape(tensor));
  if (!shape.IsArray()) {
    return errors::InvalidArgument(
        "ExecuteTensorListFromTensor() only supports normal tensor. But input "
        "shape is ",
        shape.DebugString());
  }

  std::vector<xla::XlaOp> result_parts{tensor,
                                       xla::ConstantR0<int32>(b, push_index)};
  *result = xla::Tuple(b, result_parts);
  return OkStatus();
}

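// Illustrative usage sketch (hypothetical; `PushBackExample` is not a function
// defined here): a push-back kernel would typically combine the helpers above
// roughly as follows.
//
//   Status PushBackExample(xla::XlaOp list, xla::XlaOp element,
//                          xla::XlaOp* result) {
//     xla::XlaOp initialized_list;
//     TF_RETURN_IF_ERROR(GetInitializedTensorListForElement(
//         list, element, /*element_is_tensor_list=*/false, &initialized_list));
//     return ExecuteTensorListPushBack(initialized_list, element,
//                                      /*element_is_tensor_list=*/false,
//                                      result);
//   }
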
}  // namespace tensorflow