xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/list_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/full_type.pb.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/framework/types.pb.h"
21 
22 namespace tensorflow {
23 namespace {
24 
25 // Verifies that `shapes_and_types` is a valid list handle and has the right
26 // dtype.
VerifyHandleData(shape_inference::InferenceContext * c,const std::vector<shape_inference::ShapeAndType> & shapes_and_types,DataType element_dtype)27 Status VerifyHandleData(
28     shape_inference::InferenceContext* c,
29     const std::vector<shape_inference::ShapeAndType>& shapes_and_types,
30     DataType element_dtype) {
31   if (shapes_and_types.size() != 1) {
32     return errors::InvalidArgument(
33         "Invalid handle_data for input list. Expected length of "
34         "shape_and_types: ",
35         1, " Saw: ", shapes_and_types.size());
36   }
37   const shape_inference::ShapeAndType& list_shape_type = shapes_and_types[0];
38   if (list_shape_type.dtype != element_dtype) {
39     return errors::InvalidArgument("Expected list with element dtype ",
40                                    DataTypeString(element_dtype),
41                                    " but got list with element dtype ",
42                                    DataTypeString(list_shape_type.dtype));
43   }
44   return OkStatus();
45 }
46 
IsValidTensorListHandleData(const std::vector<shape_inference::ShapeAndType> * handle_data)47 bool IsValidTensorListHandleData(
48     const std::vector<shape_inference::ShapeAndType>* handle_data) {
49   return handle_data != nullptr && handle_data->size() == 1;
50 }
51 
52 // Assumes that the handle_data is valid.
GetElementShapeFromHandleData(const std::vector<shape_inference::ShapeAndType> & shapes_and_types)53 shape_inference::ShapeHandle GetElementShapeFromHandleData(
54     const std::vector<shape_inference::ShapeAndType>& shapes_and_types) {
55   return shapes_and_types[0].shape;
56 }
57 
58 REGISTER_OP("EmptyTensorList")
59     .Input("element_shape: shape_type")
60     .Input("max_num_elements: int32")
61     .Output("handle: variant")
62     .Attr("element_dtype: type")
63     .Attr("shape_type: {int32, int64}")
64     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
65                                                         "element_dtype"))
__anon5e0c1ea50202(shape_inference::InferenceContext* c) 66     .SetShapeFn([](shape_inference::InferenceContext* c) {
67       c->set_output(0, c->Scalar());
68       DataType element_dtype;
69       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
70       shape_inference::ShapeHandle element_shape;
71       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
72           0, &element_shape));
73       const FullTypeDef& ret_types = c->ret_types();
74       c->set_output_handle_shapes_and_types(
75           0, std::vector<shape_inference::ShapeAndType>{
76                  {element_shape, element_dtype, ret_types.args(0)}});
77       return OkStatus();
78     });
79 
80 REGISTER_OP("TensorListPushBack")
81     .Input("input_handle: variant")
82     .Input("tensor: element_dtype")
83     .Output("output_handle: variant")
84     .Attr("element_dtype: type")
85     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
86                                                         "element_dtype"))
__anon5e0c1ea50302(shape_inference::InferenceContext* c) 87     .SetShapeFn([](shape_inference::InferenceContext* c) {
88       c->set_output(0, c->Scalar());
89       DataType element_dtype;
90       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
91       shape_inference::ShapeHandle element_shape = c->UnknownShape();
92 
93       auto* handle_data = c->input_handle_shapes_and_types(0);
94       if (handle_data != nullptr && handle_data->size() > 1) {
95         return errors::InvalidArgument(
96             "Trying to push to list with wrong variant data.");
97       }
98       if (IsValidTensorListHandleData(handle_data)) {
99         const shape_inference::ShapeAndType& list_shape_type =
100             (*handle_data)[0];
101         if (list_shape_type.dtype != element_dtype) {
102           return errors::InvalidArgument(
103               "Trying to push to list with wrong element dtype. List has type ",
104               DataTypeString(list_shape_type.dtype),
105               " but trying to push element with type ",
106               DataTypeString(element_dtype));
107         }
108         shape_inference::ShapeHandle ignored;
109         TF_RETURN_IF_ERROR(
110             c->Merge(element_shape, list_shape_type.shape, &ignored));
111         element_shape = list_shape_type.shape;
112       }
113       const FullTypeDef& ret_types = c->ret_types();
114       c->set_output_handle_shapes_and_types(
115           0, std::vector<shape_inference::ShapeAndType>{
116                  {element_shape, element_dtype, ret_types.args(0)}});
117       return OkStatus();
118     });
119 
120 REGISTER_OP("TensorListPushBackBatch")
121     .Input("input_handles: variant")
122     .Input("tensor: element_dtype")
123     .Output("output_handles: variant")
124     .Attr("element_dtype: type")
125     // TODO(mdan): Also support for inferring from an input type as well.
126     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
127                                                         "element_dtype"))
__anon5e0c1ea50402(shape_inference::InferenceContext* c) 128     .SetShapeFn([](shape_inference::InferenceContext* c) {
129       shape_inference::ShapeHandle input_handles;
130       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input_handles));
131 
132       shape_inference::ShapeHandle tensor;
133       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &tensor));
134 
135       TF_RETURN_IF_ERROR(
136           c->MergePrefix(tensor, input_handles, &tensor, &input_handles));
137 
138       c->set_output(0, input_handles);
139 
140       DataType element_dtype;
141       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
142       shape_inference::ShapeHandle element_shape = c->UnknownShape();
143 
144       auto* handle_data = c->input_handle_shapes_and_types(0);
145       if (handle_data != nullptr && handle_data->size() > 1) {
146         return errors::InvalidArgument(
147             "Trying to push to list with wrong variant data.");
148       }
149       if (IsValidTensorListHandleData(handle_data)) {
150         const shape_inference::ShapeAndType& list_shape_type =
151             (*handle_data)[0];
152         if (list_shape_type.dtype != element_dtype) {
153           return errors::InvalidArgument(
154               "Trying to push to list with wrong element dtype. List has type ",
155               DataTypeString(list_shape_type.dtype),
156               " but trying to push element with type ",
157               DataTypeString(element_dtype));
158         }
159         shape_inference::ShapeHandle ignored;
160         TF_RETURN_IF_ERROR(
161             c->Merge(element_shape, list_shape_type.shape, &ignored));
162         element_shape = list_shape_type.shape;
163       }
164       const FullTypeDef& ret_types = c->ret_types();
165       c->set_output_handle_shapes_and_types(
166           0, std::vector<shape_inference::ShapeAndType>{
167                  {element_shape, element_dtype, ret_types.args(0)}});
168       return OkStatus();
169     });
170 
171 REGISTER_OP("TensorListLength")
172     .Input("input_handle: variant")
173     .Output("length: int32")
174     .SetShapeFn(shape_inference::ScalarShape);
175 
176 REGISTER_OP("TensorListPopBack")
177     .Input("input_handle: variant")
178     .Input("element_shape: int32")
179     .Output("output_handle: variant")
180     .Output("tensor: element_dtype")
181     .Attr("element_dtype: type")
__anon5e0c1ea50502(shape_inference::InferenceContext* c) 182     .SetShapeFn([](shape_inference::InferenceContext* c) {
183       DataType element_dtype;
184       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
185       shape_inference::ShapeHandle tensor_shape = c->UnknownShape();
186       auto* handle_data = c->input_handle_shapes_and_types(0);
187       if (handle_data != nullptr && handle_data->size() > 1) {
188         return errors::InvalidArgument(
189             "Trying to read from list with invalid variant data.");
190       }
191       if (IsValidTensorListHandleData(handle_data)) {
192         const shape_inference::ShapeAndType& list_shape_type =
193             (*handle_data)[0];
194         if (list_shape_type.type.type_id() != TFT_ARRAY) {
195           return errors::InvalidArgument("Input argument must be a list.");
196         }
197         if (list_shape_type.dtype != element_dtype) {
198           return errors::InvalidArgument(
199               "Trying to read from list with wrong element dtype. List has "
200               "type ",
201               DataTypeString(list_shape_type.dtype),
202               " but trying to push element with type ",
203               DataTypeString(element_dtype));
204         }
205         shape_inference::ShapeHandle ignored;
206         TF_RETURN_IF_ERROR(
207             c->Merge(tensor_shape, list_shape_type.shape, &ignored));
208         c->set_output_handle_shapes_and_types(0, *handle_data);
209         tensor_shape = list_shape_type.shape;
210       }
211       c->set_output(1, tensor_shape);
212       c->set_output(0, c->Scalar());
213       return OkStatus();
214     });
215 
216 REGISTER_OP("TensorListStack")
217     .Input("input_handle: variant")
218     .Input("element_shape: int32")
219     .Output("tensor: element_dtype")
220     .Attr("element_dtype: type")
221     .Attr("num_elements: int = -1")
__anon5e0c1ea50602(shape_inference::InferenceContext* c) 222     .SetShapeFn([](shape_inference::InferenceContext* c) {
223       DataType element_dtype;
224       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
225       shape_inference::ShapeHandle element_shape = c->UnknownShape();
226       auto* handle_data = c->input_handle_shapes_and_types(0);
227       if (handle_data != nullptr && handle_data->size() > 1) {
228         return errors::InvalidArgument(
229             "Trying to read from list with wrong variant data.");
230       }
231       if (IsValidTensorListHandleData(handle_data)) {
232         const shape_inference::ShapeAndType& list_shape_type =
233             (*handle_data)[0];
234         if (list_shape_type.dtype != element_dtype) {
235           return errors::InvalidArgument(
236               "Trying to read from list with wrong element dtype. List has "
237               "type ",
238               DataTypeString(list_shape_type.dtype), " but expected type ",
239               DataTypeString(element_dtype));
240         }
241         shape_inference::ShapeHandle ignored;
242         TF_RETURN_IF_ERROR(
243             c->Merge(element_shape, list_shape_type.shape, &ignored));
244         element_shape = list_shape_type.shape;
245       }
246       shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
247       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
248           1, &element_shape_input));
249       TF_RETURN_IF_ERROR(
250           c->Merge(element_shape, element_shape_input, &element_shape));
251       int expected_num_elements = -1;
252       TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements));
253       shape_inference::ShapeHandle num_elements;
254       if (expected_num_elements == -1) {
255         num_elements = c->MakeShape({c->UnknownDim()});
256       } else {
257         num_elements = c->MakeShape({expected_num_elements});
258       }
259       shape_inference::ShapeHandle result;
260       TF_RETURN_IF_ERROR(c->Concatenate(num_elements, element_shape, &result));
261       c->set_output(0, result);
262       return OkStatus();
263     });
264 
TensorListConcatShapeInference(shape_inference::InferenceContext * c,shape_inference::ShapeHandle element_shape)265 Status TensorListConcatShapeInference(
266     shape_inference::InferenceContext* c,
267     shape_inference::ShapeHandle element_shape) {
268   DataType element_dtype;
269   TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
270   auto* handle_data = c->input_handle_shapes_and_types(0);
271   if (handle_data != nullptr && handle_data->size() > 1) {
272     return errors::InvalidArgument(
273         "Trying to read from list with wrong variant data.");
274   }
275   if (IsValidTensorListHandleData(handle_data)) {
276     const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
277     if (list_shape_type.dtype != element_dtype) {
278       return errors::InvalidArgument(
279           "Trying to read from list with wrong element dtype. List has "
280           "type ",
281           DataTypeString(list_shape_type.dtype), " but expected type ",
282           DataTypeString(element_dtype));
283     }
284     shape_inference::ShapeHandle merged;
285     TF_RETURN_IF_ERROR(c->Merge(element_shape, list_shape_type.shape, &merged));
286     element_shape = merged;
287   }
288   if (c->RankKnown(element_shape)) {
289     shape_inference::ShapeHandle result;
290     TF_RETURN_IF_ERROR(c->Subshape(element_shape, 1, &result));
291     TF_RETURN_IF_ERROR(
292         c->Concatenate(c->MakeShape({c->UnknownDim()}), result, &result));
293     c->set_output(0, result);
294   } else {
295     c->set_output(0, c->UnknownShape());
296   }
297   c->set_output(1, c->MakeShape({c->UnknownDim()}));
298   return OkStatus();
299 }
300 
301 REGISTER_OP("TensorListConcat")
302     .Input("input_handle: variant")
303     .Output("tensor: element_dtype")
304     .Output("lengths: int64")
305     .Attr("element_dtype: type")
306     .Attr("element_shape: shape = { unknown_rank: true }")
__anon5e0c1ea50702(shape_inference::InferenceContext* c) 307     .SetShapeFn([](shape_inference::InferenceContext* c) {
308       PartialTensorShape raw_element_shape;
309       TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &raw_element_shape));
310       shape_inference::ShapeHandle element_shape;
311       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(raw_element_shape,
312                                                             &element_shape));
313       return TensorListConcatShapeInference(c, element_shape);
314     });
315 
316 REGISTER_OP("TensorListConcatV2")
317     .Input("input_handle: variant")
318     .Input("element_shape: shape_type")
319     .Input("leading_dims: int64")
320     .Output("tensor: element_dtype")
321     .Output("lengths: int64")
322     .Attr("element_dtype: type")
323     .Attr("shape_type: {int32, int64}")
__anon5e0c1ea50802(shape_inference::InferenceContext* c) 324     .SetShapeFn([](shape_inference::InferenceContext* c) {
325       shape_inference::ShapeHandle element_shape;
326       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
327           1, &element_shape));
328       return TensorListConcatShapeInference(c, element_shape);
329     });
330 
331 REGISTER_OP("TensorListSplit")
332     .Input("tensor: element_dtype")
333     .Input("element_shape: shape_type")
334     .Input("lengths: int64")
335     .Output("output_handle: variant")
336     .Attr("element_dtype: type")
337     .Attr("shape_type: {int32, int64}")
338     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
339                                                         "element_dtype"))
__anon5e0c1ea50902(shape_inference::InferenceContext* c) 340     .SetShapeFn([](shape_inference::InferenceContext* c) {
341       c->set_output(0, c->Scalar());
342       DataType element_dtype;
343       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
344       shape_inference::ShapeHandle tensor_shape = c->input(0);
345       shape_inference::ShapeHandle ignored;
346       // Check that tensor is at least a vector.
347       TF_RETURN_IF_ERROR(c->WithRankAtLeast(tensor_shape, 1, &ignored));
348       // Check that lengths is a vector.
349       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &ignored));
350       shape_inference::ShapeHandle element_shape_from_tensor_shape;
351       TF_RETURN_IF_ERROR(
352           c->Subshape(tensor_shape, 1, &element_shape_from_tensor_shape));
353       TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape({c->UnknownDim()}),
354                                         element_shape_from_tensor_shape,
355                                         &element_shape_from_tensor_shape));
356       shape_inference::ShapeHandle element_shape;
357       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
358           1, &element_shape));
359       TF_RETURN_IF_ERROR(c->Merge(element_shape_from_tensor_shape,
360                                   element_shape,
361                                   &element_shape_from_tensor_shape));
362       const FullTypeDef& ret_types = c->ret_types();
363       c->set_output_handle_shapes_and_types(
364           0, std::vector<shape_inference::ShapeAndType>{
365                  {element_shape, element_dtype, ret_types.args(0)}});
366       return OkStatus();
367     });
368 
369 REGISTER_OP("TensorListFromTensor")
370     .Input("tensor: element_dtype")
371     .Input("element_shape: shape_type")
372     .Output("output_handle: variant")
373     .Attr("element_dtype: type")
374     .Attr("shape_type: {int32, int64}")
375     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
376                                                         "element_dtype"))
377     .SetForwardTypeFn(full_type::UnaryContainerCreate(TFT_ARRAY, 0))
__anon5e0c1ea50a02(shape_inference::InferenceContext* c) 378     .SetShapeFn([](shape_inference::InferenceContext* c) {
379       c->set_output(0, c->Scalar());
380       DataType element_dtype;
381       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
382       shape_inference::ShapeHandle tensor_shape = c->input(0);
383       shape_inference::ShapeHandle tensor_shape_except_first_dim;
384       TF_RETURN_IF_ERROR(
385           c->Subshape(tensor_shape, 1, &tensor_shape_except_first_dim));
386       shape_inference::ShapeHandle element_shape;
387       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
388           1, &element_shape));
389       TF_RETURN_IF_ERROR(c->Merge(tensor_shape_except_first_dim, element_shape,
390                                   &tensor_shape_except_first_dim));
391       const FullTypeDef& ret_types = c->ret_types();
392       c->set_output_handle_shapes_and_types(
393           0, std::vector<shape_inference::ShapeAndType>{
394                  {element_shape, element_dtype, ret_types.args(0)}});
395       return OkStatus();
396     });
397 
398 REGISTER_OP("TensorListElementShape")
399     .Input("input_handle: variant")
400     .Output("element_shape: shape_type")
401     .Attr("shape_type: {int32, int64}")
__anon5e0c1ea50b02(shape_inference::InferenceContext* c) 402     .SetShapeFn([](shape_inference::InferenceContext* c) {
403       auto* handle_data = c->input_handle_shapes_and_types(0);
404       // `TensorListElementShape` returns the scalar -1 if the rank of
405       // element_shape is unknown else returns the shape vector (with possibly
406       // unknown dims).
407       if (!IsValidTensorListHandleData(handle_data)) {
408         c->set_output(0, c->UnknownShape());
409         return OkStatus();
410       }
411       if (c->RankKnown((*handle_data)[0].shape)) {
412         c->set_output(0, c->Vector(c->Rank((*handle_data)[0].shape)));
413       } else {
414         c->set_output(0, c->UnknownShape());
415       }
416       return OkStatus();
417     });
418 
419 REGISTER_OP("TensorListReserve")
420     .Input("element_shape: shape_type")
421     .Input("num_elements: int32")
422     .Output("handle: variant")
423     .Attr("element_dtype: type")
424     .Attr("shape_type: {int32, int64}")
425     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
426                                                         "element_dtype"))
__anon5e0c1ea50c02(shape_inference::InferenceContext* c) 427     .SetShapeFn([](shape_inference::InferenceContext* c) {
428       c->set_output(0, c->Scalar());
429       shape_inference::ShapeHandle element_shape;
430       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
431           0, &element_shape));
432       DataType element_dtype;
433       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
434       const FullTypeDef& ret_types = c->ret_types();
435       c->set_output_handle_shapes_and_types(
436           0, std::vector<shape_inference::ShapeAndType>{
437                  {element_shape, element_dtype, ret_types.args(0)}});
438       return OkStatus();
439     });
440 
441 REGISTER_OP("TensorListGetItem")
442     .Input("input_handle: variant")
443     .Input("index: int32")
444     .Input("element_shape: int32")
445     .Output("item: element_dtype")
446     .Attr("element_dtype: type")
__anon5e0c1ea50d02(shape_inference::InferenceContext* c) 447     .SetShapeFn([](shape_inference::InferenceContext* c) {
448       DataType element_dtype;
449       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
450       auto* handle_data = c->input_handle_shapes_and_types(0);
451       shape_inference::ShapeHandle element_shape = c->UnknownShape();
452       if (IsValidTensorListHandleData(handle_data)) {
453         const shape_inference::ShapeAndType& list_shape_type =
454             (*handle_data)[0];
455         element_shape = list_shape_type.shape;
456         if (list_shape_type.dtype != element_dtype) {
457           return errors::InvalidArgument("Expected list with element dtype ",
458                                          DataTypeString(element_dtype),
459                                          " but got list with element dtype ",
460                                          DataTypeString(list_shape_type.dtype));
461         }
462       }
463       shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
464       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
465           2, &element_shape_input));
466       TF_RETURN_IF_ERROR(
467           c->Merge(element_shape, element_shape_input, &element_shape));
468       c->set_output(0, element_shape);
469       return OkStatus();
470     });
471 
472 REGISTER_OP("TensorListResize")
473     .Input("input_handle: variant")
474     .Input("size: int32")
475     .Output("output_handle: variant")
__anon5e0c1ea50e02(shape_inference::InferenceContext* c) 476     .SetShapeFn([](shape_inference::InferenceContext* c) {
477       // Check that `size` has scalar shape.
478       shape_inference::ShapeHandle size_shape = c->input(1);
479       shape_inference::ShapeHandle unused;
480       TF_RETURN_IF_ERROR(c->WithRank(size_shape, 0, &unused));
481       c->set_output(0, c->Scalar());
482       auto* handle_data = c->input_handle_shapes_and_types(0);
483       if (IsValidTensorListHandleData(handle_data)) {
484         c->set_output_handle_shapes_and_types(0, *handle_data);
485       }
486       return OkStatus();
487     });
488 
489 REGISTER_OP("TensorListSetItem")
490     .Input("input_handle: variant")
491     .Input("index: int32")
492     .Input("item: element_dtype")
493     .Output("output_handle: variant")
494     .Attr("element_dtype: type")
495     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
496                                                         "element_dtype"))
497     .SetForwardTypeFn(full_type::UnaryContainerAdd(TFT_ARRAY,
498                                                    /*container_idx=*/0,
499                                                    /*element_idx=*/2,
500                                                    /*homogeneous=*/true))
__anon5e0c1ea50f02(shape_inference::InferenceContext* c) 501     .SetShapeFn([](shape_inference::InferenceContext* c) {
502       DataType element_dtype;
503       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
504       auto* handle_data = c->input_handle_shapes_and_types(0);
505       c->set_output(0, c->Scalar());
506       if (IsValidTensorListHandleData(handle_data)) {
507         const shape_inference::ShapeAndType& list_shape_type =
508             (*handle_data)[0];
509         shape_inference::ShapeHandle item_shape = c->input(2);
510         TF_RETURN_IF_ERROR(
511             c->Merge(item_shape, list_shape_type.shape, &item_shape));
512         c->set_output_handle_shapes_and_types(0, *handle_data);
513       } else {
514         const FullTypeDef& ret_types = c->ret_types();
515         c->set_output_handle_shapes_and_types(
516             0, std::vector<shape_inference::ShapeAndType>{
517                    {c->UnknownShape(), element_dtype, ret_types.args(0)}});
518       }
519       return OkStatus();
520     });
521 
522 REGISTER_OP("TensorListGather")
523     .Input("input_handle: variant")
524     .Input("indices: int32")
525     .Input("element_shape: int32")
526     .Output("values: element_dtype")
527     .Attr("element_dtype: type")
__anon5e0c1ea51002(shape_inference::InferenceContext* c) 528     .SetShapeFn([](shape_inference::InferenceContext* c) {
529       DataType element_dtype;
530       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
531       auto* handle_data = c->input_handle_shapes_and_types(0);
532       shape_inference::ShapeHandle element_shape = c->UnknownShape();
533       if (IsValidTensorListHandleData(handle_data)) {
534         const shape_inference::ShapeAndType& list_shape_type =
535             (*handle_data)[0];
536         element_shape = list_shape_type.shape;
537         if (list_shape_type.dtype != element_dtype) {
538           return errors::InvalidArgument("Expected list with element dtype ",
539                                          DataTypeString(element_dtype),
540                                          " but got list with element dtype ",
541                                          DataTypeString(list_shape_type.dtype));
542         }
543       }
544       shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
545       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
546           2, &element_shape_input));
547       TF_RETURN_IF_ERROR(
548           c->Merge(element_shape, element_shape_input, &element_shape));
549       shape_inference::ShapeHandle out;
550       TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
551       c->set_output(0, out);
552       return OkStatus();
553     });
554 
555 REGISTER_OP("TensorListScatter")
556     .Input("tensor: element_dtype")
557     .Input("indices: int32")
558     .Input("element_shape: shape_type")
559     .Output("output_handle: variant")
560     .Attr("element_dtype: type")
561     .Attr("shape_type: {int32, int64}")
562     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
563                                                         "element_dtype"))
__anon5e0c1ea51102(shape_inference::InferenceContext* c) 564     .SetShapeFn([](shape_inference::InferenceContext* c) {
565       DataType element_dtype;
566       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
567       shape_inference::ShapeHandle element_shape;
568       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
569           2, &element_shape));
570       const FullTypeDef& ret_types = c->ret_types();
571       c->set_output_handle_shapes_and_types(
572           0, std::vector<shape_inference::ShapeAndType>{
573                  {element_shape, element_dtype, ret_types.args(0)}});
574       c->set_output(0, c->Scalar());
575       return OkStatus();
576     });
577 
578 REGISTER_OP("TensorListScatterV2")
579     .Input("tensor: element_dtype")
580     .Input("indices: int32")
581     .Input("element_shape: shape_type")
582     .Input("num_elements: int32")
583     .Output("output_handle: variant")
584     .Attr("element_dtype: type")
585     .Attr("shape_type: {int32, int64}")
586     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
587                                                         "element_dtype"))
__anon5e0c1ea51202(shape_inference::InferenceContext* c) 588     .SetShapeFn([](shape_inference::InferenceContext* c) {
589       DataType element_dtype;
590       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
591       shape_inference::ShapeHandle element_shape;
592       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
593           2, &element_shape));
594       const FullTypeDef& ret_types = c->ret_types();
595       c->set_output_handle_shapes_and_types(
596           0, std::vector<shape_inference::ShapeAndType>{
597                  {element_shape, element_dtype, ret_types.args(0)}});
598       c->set_output(0, c->Scalar());
599       return OkStatus();
600     });
601 
602 REGISTER_OP("TensorListScatterIntoExistingList")
603     .Input("input_handle: variant")
604     .Input("tensor: element_dtype")
605     .Input("indices: int32")
606     .Output("output_handle: variant")
607     .Attr("element_dtype: type")
608     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
609                                                         "element_dtype"))
__anon5e0c1ea51302(shape_inference::InferenceContext* c) 610     .SetShapeFn([](shape_inference::InferenceContext* c) {
611       shape_inference::ShapeHandle ignored;
612       // Check that tensor is at least a vector.
613       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &ignored));
614       // Check that indices is a vector.
615       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &ignored));
616 
617       DataType element_dtype;
618       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
619       shape_inference::ShapeHandle element_shape = c->UnknownShape();
620 
621       auto* handle_data = c->input_handle_shapes_and_types(0);
622       if (IsValidTensorListHandleData(handle_data)) {
623         TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
624         element_shape = GetElementShapeFromHandleData(*handle_data);
625       }
626       const FullTypeDef& ret_types = c->ret_types();
627       c->set_output_handle_shapes_and_types(
628           0, std::vector<shape_inference::ShapeAndType>{
629                  {element_shape, element_dtype, ret_types.args(0)}});
630       c->set_output(0, c->Scalar());
631       return OkStatus();
632     });
633 
634 REGISTER_OP("TensorListConcatLists")
635     .Input("input_a: variant")
636     .Input("input_b: variant")
637     .Attr("element_dtype: type")
638     .Output("output: variant")
639     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
640                                                         "element_dtype"))
__anon5e0c1ea51402(shape_inference::InferenceContext* c) 641     .SetShapeFn([](shape_inference::InferenceContext* c) {
642       auto input_a = c->input(0);
643       auto input_b = c->input(1);
644       TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input_a));
645       c->set_output(0, input_a);
646 
647       DataType element_dtype;
648       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
649 
650       auto* handle_data_a = c->input_handle_shapes_and_types(0);
651       auto* handle_data_b = c->input_handle_shapes_and_types(1);
652       bool handle_data_a_nonempty = handle_data_a && !handle_data_a->empty();
653       bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
654       if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
655         const FullTypeDef& ret_types = c->ret_types();
656         c->set_output_handle_shapes_and_types(
657             0, {{c->UnknownShape(), element_dtype, ret_types.args(0)}});
658         return OkStatus();
659       }
660       shape_inference::ShapeAndType list_shape_type_a =
661           handle_data_a_nonempty ? handle_data_a->at(0) : handle_data_b->at(0);
662       const shape_inference::ShapeAndType& list_shape_type_b =
663           handle_data_b_nonempty ? handle_data_b->at(0) : handle_data_a->at(0);
664       if (list_shape_type_a.dtype != element_dtype) {
665         return errors::InvalidArgument("input_a.type != element_dtype: ",
666                                        DataTypeString(list_shape_type_a.dtype),
667                                        " vs. ", DataTypeString(element_dtype));
668       }
669       if (list_shape_type_b.dtype != element_dtype) {
670         return errors::InvalidArgument("input_b.type != element_dtype: ",
671                                        DataTypeString(list_shape_type_b.dtype),
672                                        " vs. ", DataTypeString(element_dtype));
673       }
674       TF_RETURN_IF_ERROR(c->Merge(list_shape_type_a.shape,
675                                   list_shape_type_b.shape,
676                                   &list_shape_type_a.shape));
677       c->set_output_handle_shapes_and_types(0, {list_shape_type_a});
678       return OkStatus();
679     });
680 
681 }  // namespace
682 }  // namespace tensorflow
683