xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/experimental_dataset_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/full_type.pb.h"
17 #include "tensorflow/core/framework/op.h"
18 
19 namespace tensorflow {
20 
21 REGISTER_OP("AssertCardinalityDataset")
22     .Input("input_dataset: variant")
23     .Input("cardinality: int64")
24     .Output("handle: variant")
25     .Attr("output_types: list(type) >= 1")
26     .Attr("output_shapes: list(shape) >= 1")
27     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
28                                                            "output_types"))
__anon46ca241c0102(shape_inference::InferenceContext* c) 29     .SetShapeFn([](shape_inference::InferenceContext* c) {
30       shape_inference::ShapeHandle unused;
31       // cardinality should be a scalar.
32       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
33       return shape_inference::ScalarShape(c);
34     });
35 
36 REGISTER_OP("AssertNextDataset")
37     .Input("input_dataset: variant")
38     .Input("transformations: string")
39     .Output("handle: variant")
40     .Attr("output_types: list(type) >= 1")
41     .Attr("output_shapes: list(shape) >= 1")
42     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
43                                                            "output_types"))
__anon46ca241c0202(shape_inference::InferenceContext* c) 44     .SetShapeFn([](shape_inference::InferenceContext* c) {
45       shape_inference::ShapeHandle unused;
46       // transformations should be a vector.
47       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
48       return shape_inference::ScalarShape(c);
49     });
50 
51 REGISTER_OP("ExperimentalAssertNextDataset")
52     .Input("input_dataset: variant")
53     .Input("transformations: string")
54     .Output("handle: variant")
55     .Attr("output_types: list(type) >= 1")
56     .Attr("output_shapes: list(shape) >= 1")
57     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
58                                                            "output_types"))
__anon46ca241c0302(shape_inference::InferenceContext* c) 59     .SetShapeFn([](shape_inference::InferenceContext* c) {
60       shape_inference::ShapeHandle unused;
61       // transformations should be a vector.
62       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
63       return shape_inference::ScalarShape(c);
64     });
65 
66 REGISTER_OP("AssertPrevDataset")
67     .Input("input_dataset: variant")
68     .Input("transformations: string")
69     .Output("handle: variant")
70     .Attr("output_types: list(type) >= 1")
71     .Attr("output_shapes: list(shape) >= 1")
72     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
73                                                            "output_types"))
__anon46ca241c0402(shape_inference::InferenceContext* c) 74     .SetShapeFn([](shape_inference::InferenceContext* c) {
75       shape_inference::ShapeHandle unused;
76       // transformations should be a vector.
77       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
78       return shape_inference::ScalarShape(c);
79     });
80 
// Registers "AutoShardDataset": takes `input_dataset` plus int64 inputs
// `num_workers` and `index`, and produces a dataset `handle` whose full type
// is a TFT_DATASET container over `output_types`.
// The forward type function ContainerMap(TFT_DATASET, input 0, ShardTensor)
// presumably applies full_type::ShardTensor to the element types of input 0
// (NOTE(review): confirm against the full_type inference helpers).
// NOTE(review): the shape function is plain shape_inference::ScalarShape, so
// the ranks of `num_workers` and `index` are presumably not validated at
// graph-construction time.
81 REGISTER_OP("AutoShardDataset")
82     .Input("input_dataset: variant")
83     .Input("num_workers: int64")
84     .Input("index: int64")
85     .Output("handle: variant")
86     .Attr("auto_shard_policy: int = 0")
87     .Attr("output_types: list(type) >= 1")
88     .Attr("output_shapes: list(shape) >= 1")
89     .Attr("num_replicas: int = 0")
90     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
91                                                            "output_types"))
92     .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
93                                               full_type::ShardTensor))
94     .SetShapeFn(shape_inference::ScalarShape);
95 
// Registers "ExperimentalAutoShardDataset". Its inputs match
// "AutoShardDataset" (`input_dataset`, int64 `num_workers`, int64 `index`).
// Unlike that op, it has no `num_replicas` attr and sets no forward type
// function. The input/attr lists form the registered op signature, so they
// are presumably kept as-is for compatibility with existing graphs.
// The handle output gets scalar shape via shape_inference::ScalarShape.
96 REGISTER_OP("ExperimentalAutoShardDataset")
97     .Input("input_dataset: variant")
98     .Input("num_workers: int64")
99     .Input("index: int64")
100     .Output("handle: variant")
101     .Attr("auto_shard_policy: int = 0")
102     .Attr("output_types: list(type) >= 1")
103     .Attr("output_shapes: list(shape) >= 1")
104     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
105                                                            "output_types"))
106     .SetShapeFn(shape_inference::ScalarShape);
107 
108 REGISTER_OP("BytesProducedStatsDataset")
109     .Input("input_dataset: variant")
110     .Input("tag: string")
111     .Output("handle: variant")
112     .Attr("output_types: list(type) >= 1")
113     .Attr("output_shapes: list(shape) >= 1")
114     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
115                                                            "output_types"))
__anon46ca241c0502(shape_inference::InferenceContext* c) 116     .SetShapeFn([](shape_inference::InferenceContext* c) {
117       shape_inference::ShapeHandle tag_shape;
118       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
119       return shape_inference::ScalarShape(c);
120     });
121 
122 REGISTER_OP("ExperimentalBytesProducedStatsDataset")
123     .Input("input_dataset: variant")
124     .Input("tag: string")
125     .Output("handle: variant")
126     .Attr("output_types: list(type) >= 1")
127     .Attr("output_shapes: list(shape) >= 1")
128     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
129                                                            "output_types"))
__anon46ca241c0602(shape_inference::InferenceContext* c) 130     .SetShapeFn([](shape_inference::InferenceContext* c) {
131       shape_inference::ShapeHandle tag_shape;
132       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
133       return shape_inference::ScalarShape(c);
134     });
135 
// Registers "ChooseFastestBranchDataset": takes `input_dataset`, int64
// `ratio_numerator` and `ratio_denominator`, and a flat `other_arguments`
// list typed by `Targuments`, and produces a dataset `handle`.
// `branches` holds at least one function. `other_arguments_lengths` is
// presumably how `other_arguments` is partitioned among `branches`
// (NOTE(review): confirm against the kernel).
// The shape function is plain shape_inference::ScalarShape, so input ranks
// are presumably not validated here.
136 REGISTER_OP("ChooseFastestBranchDataset")
137     .Input("input_dataset: variant")
138     .Input("ratio_numerator: int64")
139     .Input("ratio_denominator: int64")
140     .Input("other_arguments: Targuments")
141     .Output("handle: variant")
142     .Attr("Targuments: list(type) >= 0")
143     .Attr("num_elements_per_branch: int >= 1")
144     .Attr("branches: list(func) >= 1")
145     .Attr("other_arguments_lengths: list(int) >= 1")
146     .Attr("output_types: list(type) >= 1")
147     .Attr("output_shapes: list(shape) >= 1")
148     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
149                                                            "output_types"))
150     .SetShapeFn(shape_inference::ScalarShape);
151 
// Registers "ChooseFastestDataset": takes N >= 2 dataset variants
// (`input_datasets`) and an int `num_experiments` attr, and produces a single
// dataset `handle` (scalar shape via shape_inference::ScalarShape).
152 REGISTER_OP("ChooseFastestDataset")
153     .Input("input_datasets: N * variant")
154     .Output("handle: variant")
155     .Attr("N: int >= 2")
156     .Attr("num_experiments: int")
157     .Attr("output_types: list(type) >= 1")
158     .Attr("output_shapes: list(shape) >= 1")
159     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
160                                                            "output_types"))
161     .SetShapeFn(shape_inference::ScalarShape);
162 
// Registers "ExperimentalChooseFastestDataset". Its inputs, output, attrs
// and shape function are identical to those of "ChooseFastestDataset".
163 REGISTER_OP("ExperimentalChooseFastestDataset")
164     .Input("input_datasets: N * variant")
165     .Output("handle: variant")
166     .Attr("N: int >= 2")
167     .Attr("num_experiments: int")
168     .Attr("output_types: list(type) >= 1")
169     .Attr("output_shapes: list(shape) >= 1")
170     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
171                                                            "output_types"))
172     .SetShapeFn(shape_inference::ScalarShape);
173 
// Registers "CompressElement": takes the tensors in `components` (one per
// entry of `input_types`) and returns a single `compressed` variant, which
// is declared with scalar shape. No type constructor is set for this op.
174 REGISTER_OP("CompressElement")
175     .Input("components: input_types")
176     .Output("compressed: variant")
177     .Attr("input_types: list(type) >= 1")
178     .SetShapeFn(shape_inference::ScalarShape);
179 
// Registers "UncompressElement": takes a `compressed` variant and returns
// `components` typed by `output_types`. Output shapes come from
// shape_inference::DatasetIteratorShape, which presumably reads the
// `output_shapes` attr (NOTE(review): confirm in common_shape_fns).
180 REGISTER_OP("UncompressElement")
181     .Input("compressed: variant")
182     .Output("components: output_types")
183     .Attr("output_types: list(type) >= 1")
184     .Attr("output_shapes: list(shape) >= 1")
185     .SetShapeFn(shape_inference::DatasetIteratorShape);
186 
// Registers "ComputeBatchSize": takes a dataset variant and returns an int64
// `batch_size` with scalar shape.
187 REGISTER_OP("ComputeBatchSize")
188     .Input("input_dataset : variant")
189     .Output("batch_size : int64")
190     .SetShapeFn(shape_inference::ScalarShape);
191 
192 REGISTER_OP("CSVDataset")
193     .Input("filenames: string")
194     .Input("compression_type: string")
195     .Input("buffer_size: int64")
196     .Input("header: bool")
197     .Input("field_delim: string")
198     .Input("use_quote_delim: bool")
199     .Input("na_value: string")
200     .Input("select_cols: int64")
201     .Input("record_defaults: output_types")
202     .Output("handle: variant")
203     .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
204     .Attr("output_shapes: list(shape) >= 1")
205     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
206     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
207                                                            "output_types"))
__anon46ca241c0702(shape_inference::InferenceContext* c) 208     .SetShapeFn([](shape_inference::InferenceContext* c) {
209       shape_inference::ShapeHandle unused;
210       // `filenames` must be a scalar or a vector.
211       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
212       // `compression_type`, `buffer_size`, `header`, `field_delim`,
213       // `use_quote_delim`, `na_value` must be scalars
214       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
215       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
216       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
217       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
218       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
219       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
220       // `select_cols` must be a vector
221       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
222       // `record_defaults` must be lists of scalars
223       for (size_t i = 8; i < c->num_inputs(); ++i) {
224         shape_inference::ShapeHandle v;
225         TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
226         if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
227           return errors::InvalidArgument(
228               "Shape of a default must be a length-0 or length-1 vector, or a "
229               "scalar.");
230         }
231       }
232       return shape_inference::ScalarShape(c);
233     });
234 
235 REGISTER_OP("CSVDatasetV2")
236     .Input("filenames: string")
237     .Input("compression_type: string")
238     .Input("buffer_size: int64")
239     .Input("header: bool")
240     .Input("field_delim: string")
241     .Input("use_quote_delim: bool")
242     .Input("na_value: string")
243     .Input("select_cols: int64")
244     .Input("record_defaults: output_types")
245     .Input("exclude_cols: int64")
246     .Output("handle: variant")
247     .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
248     .Attr("output_shapes: list(shape) >= 1")
249     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
250     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
251                                                            "output_types"))
__anon46ca241c0802(shape_inference::InferenceContext* c) 252     .SetShapeFn([](shape_inference::InferenceContext* c) {
253       shape_inference::ShapeHandle unused;
254       // `filenames` must be a scalar or a vector.
255       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
256       // `compression_type`, `buffer_size`, `header`, `field_delim`,
257       // `use_quote_delim`, `na_value` must be scalars
258       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
259       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
260       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
261       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
262       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
263       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
264       // `select_cols` must be a vector
265       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
266       // `exclude_cols` must be a vector
267       TF_RETURN_IF_ERROR(
268           c->WithRank(c->input(c->num_inputs() - 1), 1, &unused));
269       // `record_defaults` must be lists of scalars
270       for (size_t i = 8; i < c->num_inputs() - 1; ++i) {
271         shape_inference::ShapeHandle v;
272         TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
273         if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
274           return errors::InvalidArgument(
275               "Shape of a default must be a length-0 or length-1 vector, or a "
276               "scalar.");
277         }
278       }
279       return shape_inference::ScalarShape(c);
280     });
281 
282 REGISTER_OP("ExperimentalCSVDataset")
283     .Input("filenames: string")
284     .Input("compression_type: string")
285     .Input("buffer_size: int64")
286     .Input("header: bool")
287     .Input("field_delim: string")
288     .Input("use_quote_delim: bool")
289     .Input("na_value: string")
290     .Input("select_cols: int64")
291     .Input("record_defaults: output_types")
292     .Output("handle: variant")
293     .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
294     .Attr("output_shapes: list(shape) >= 1")
295     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
296     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
297                                                            "output_types"))
__anon46ca241c0902(shape_inference::InferenceContext* c) 298     .SetShapeFn([](shape_inference::InferenceContext* c) {
299       shape_inference::ShapeHandle unused;
300       // `filenames` must be a scalar or a vector.
301       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
302       // `compression_type`, `buffer_size`, `header`, `field_delim`,
303       // `use_quote_delim`, `na_value` must be scalars
304       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
305       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
306       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
307       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
308       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
309       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
310       // `select_cols` must be a vector
311       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
312       // `record_defaults` must be lists of scalars
313       for (size_t i = 8; i < c->num_inputs(); ++i) {
314         shape_inference::ShapeHandle v;
315         TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
316         if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
317           return errors::InvalidArgument(
318               "Shape of a default must be a length-0 or length-1 vector, or a "
319               "scalar.");
320         }
321       }
322       return shape_inference::ScalarShape(c);
323     });
324 
// Registers "ExperimentalDatasetCardinality": takes a dataset variant and
// returns an int64 `cardinality` with scalar shape.
325 REGISTER_OP("ExperimentalDatasetCardinality")
326     .Input("input_dataset: variant")
327     .Output("cardinality: int64")
328     .SetShapeFn(shape_inference::ScalarShape);
329 
// Registers "DatasetFromGraph": takes a string `graph_def` and produces a
// dataset `handle` with scalar shape. The type constructor is
// UnaryGeneric(TFT_DATASET). The forward type function
// Decode(TFT_STRING, 0) presumably derives the output type from input 0
// (NOTE(review): confirm in the full_type helpers).
330 REGISTER_OP("DatasetFromGraph")
331     .Input("graph_def: string")
332     .Output("handle: variant")
333     .SetTypeConstructor(full_type::UnaryGeneric(TFT_DATASET))
334     .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0))
335     .SetShapeFn(shape_inference::ScalarShape);
336 
337 // TODO(b/124308596): Instead of conservatively marking this op as stateful,
338 // implement a mechanism to determine whether `dataset` has a side-effect
339 // and use it to decide whether to use a stateless or stateful version of this
340 // op.
// Registers "DatasetToTFRecord": takes `input_dataset`, a string `filename`
// and a string `compression_type`, and declares no outputs
// (shape_inference::NoOutputs). The shape function does not check the ranks
// of `filename` or `compression_type`.
341 REGISTER_OP("DatasetToTFRecord")
342     .Input("input_dataset: variant")
343     .Input("filename: string")
344     .Input("compression_type: string")
345     .SetIsStateful()
346     .SetShapeFn(shape_inference::NoOutputs);
347 
// Registers "ExperimentalDatasetToTFRecord", with the same inputs, stateful
// marking and absence of outputs as "DatasetToTFRecord".
348 REGISTER_OP("ExperimentalDatasetToTFRecord")
349     .Input("input_dataset: variant")
350     .Input("filename: string")
351     .Input("compression_type: string")
352     .SetIsStateful()
353     .SetShapeFn(shape_inference::NoOutputs);
354 
355 REGISTER_OP("DenseToSparseBatchDataset")
356     .Input("input_dataset: variant")
357     .Input("batch_size: int64")
358     .Input("row_shape: int64")
359     .Output("handle: variant")
360     .Attr("output_types: list(type) >= 1")
361     .Attr("output_shapes: list(shape) >= 1")
362     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
363                                                            "output_types"))
__anon46ca241c0a02(shape_inference::InferenceContext* c) 364     .SetShapeFn([](shape_inference::InferenceContext* c) {
365       shape_inference::ShapeHandle unused;
366       // batch_size should be a scalar.
367       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
368       // row_shape should be a 1-D vector.
369       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
370       return shape_inference::ScalarShape(c);
371     });
372 
373 REGISTER_OP("ExperimentalDenseToSparseBatchDataset")
374     .Input("input_dataset: variant")
375     .Input("batch_size: int64")
376     .Input("row_shape: int64")
377     .Output("handle: variant")
378     .Attr("output_types: list(type) >= 1")
379     .Attr("output_shapes: list(shape) >= 1")
380     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
381                                                            "output_types"))
__anon46ca241c0b02(shape_inference::InferenceContext* c) 382     .SetShapeFn([](shape_inference::InferenceContext* c) {
383       shape_inference::ShapeHandle unused;
384       // batch_size should be a scalar.
385       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
386       // row_shape should be a 1-D vector.
387       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
388       return shape_inference::ScalarShape(c);
389     });
390 
// Registers "DirectedInterleaveDataset": takes a `selector_input_dataset`
// plus N >= 1 `data_input_datasets`, and produces a dataset `handle` (scalar
// shape). The `stop_on_empty_dataset` attr defaults to false.
391 REGISTER_OP("DirectedInterleaveDataset")
392     .Input("selector_input_dataset: variant")
393     .Input("data_input_datasets: N * variant")
394     .Output("handle: variant")
395     .Attr("output_types: list(type) >= 1")
396     .Attr("output_shapes: list(shape) >= 1")
397     .Attr("N: int >= 1")
398     .Attr("stop_on_empty_dataset: bool = false")
399     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
400                                                            "output_types"))
401     .SetShapeFn(shape_inference::ScalarShape);
402 
// Registers "ExperimentalDirectedInterleaveDataset". It matches
// "DirectedInterleaveDataset" except that it has no `stop_on_empty_dataset`
// attr.
403 REGISTER_OP("ExperimentalDirectedInterleaveDataset")
404     .Input("selector_input_dataset: variant")
405     .Input("data_input_datasets: N * variant")
406     .Output("handle: variant")
407     .Attr("output_types: list(type) >= 1")
408     .Attr("output_shapes: list(shape) >= 1")
409     .Attr("N: int >= 1")
410     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
411                                                            "output_types"))
412     .SetShapeFn(shape_inference::ScalarShape);
413 
// Registers "GroupByReducerDataset": takes `input_dataset` plus one list of
// extra arguments for each of `key_func`, `init_func`, `reduce_func` and
// `finalize_func` (each list typed by the matching T*_other_arguments attr),
// and produces a dataset `handle` with scalar shape.
// The op is marked stateful via SetIsStateful().
414 REGISTER_OP("GroupByReducerDataset")
415     .Input("input_dataset: variant")
416     .Input("key_func_other_arguments: Tkey_func_other_arguments")
417     .Input("init_func_other_arguments: Tinit_func_other_arguments")
418     .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
419     .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
420     .Output("handle: variant")
421     .Attr("key_func: func")
422     .Attr("init_func: func")
423     .Attr("reduce_func: func")
424     .Attr("finalize_func: func")
425     .Attr("Tkey_func_other_arguments: list(type) >= 0")
426     .Attr("Tinit_func_other_arguments: list(type) >= 0")
427     .Attr("Treduce_func_other_arguments: list(type) >= 0")
428     .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
429     .Attr("output_types: list(type) >= 1")
430     .Attr("output_shapes: list(shape) >= 1")
431     .SetIsStateful()
432     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
433                                                            "output_types"))
434     .SetShapeFn(shape_inference::ScalarShape);
435 
// Registers "ExperimentalGroupByReducerDataset". Its inputs, attrs, stateful
// marking and shape function are identical to "GroupByReducerDataset".
436 REGISTER_OP("ExperimentalGroupByReducerDataset")
437     .Input("input_dataset: variant")
438     .Input("key_func_other_arguments: Tkey_func_other_arguments")
439     .Input("init_func_other_arguments: Tinit_func_other_arguments")
440     .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
441     .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
442     .Output("handle: variant")
443     .Attr("key_func: func")
444     .Attr("init_func: func")
445     .Attr("reduce_func: func")
446     .Attr("finalize_func: func")
447     .Attr("Tkey_func_other_arguments: list(type) >= 0")
448     .Attr("Tinit_func_other_arguments: list(type) >= 0")
449     .Attr("Treduce_func_other_arguments: list(type) >= 0")
450     .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
451     .Attr("output_types: list(type) >= 1")
452     .Attr("output_shapes: list(shape) >= 1")
453     .SetIsStateful()
454     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
455                                                            "output_types"))
456     .SetShapeFn(shape_inference::ScalarShape);
457 
// Registers "GroupByWindowDataset": takes `input_dataset` plus one list of
// extra arguments for each of `key_func`, `reduce_func` and
// `window_size_func` (each typed by the matching T*_other_arguments attr),
// and produces a dataset `handle` with scalar shape.
// `metadata` is a string attr defaulting to ''.
458 REGISTER_OP("GroupByWindowDataset")
459     .Input("input_dataset: variant")
460     .Input("key_func_other_arguments: Tkey_func_other_arguments")
461     .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
462     .Input(
463         "window_size_func_other_arguments: Twindow_size_func_other_arguments")
464     .Output("handle: variant")
465     .Attr("key_func: func")
466     .Attr("reduce_func: func")
467     .Attr("window_size_func: func")
468     .Attr("Tkey_func_other_arguments: list(type) >= 0")
469     .Attr("Treduce_func_other_arguments: list(type) >= 0")
470     .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
471     .Attr("output_types: list(type) >= 1")
472     .Attr("output_shapes: list(shape) >= 1")
473     .Attr("metadata: string = ''")
474     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
475                                                            "output_types"))
476     .SetShapeFn(shape_inference::ScalarShape);
477 
// Registers "GetElementAtIndex": takes a `dataset` variant and an int64
// `index`, and returns `components` typed by `output_types`. Output shapes
// come from shape_inference::DatasetIteratorShape.
// NOTE(review): the outputs are component tensors, not a dataset handle,
// yet the type constructor is the same TFT_DATASET container used by the
// dataset-producing ops in this file — confirm this is intended.
// NOTE(review): nothing here visibly checks the rank of `index`; confirm
// whether DatasetIteratorShape does.
478 REGISTER_OP("GetElementAtIndex")
479     .Input("dataset: variant")
480     .Input("index: int64")
481     .Output("components: output_types")
482     .Attr("output_types: list(type) >= 1")
483     .Attr("output_shapes: list(shape) >= 1")
484     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
485                                                            "output_types"))
486     .SetShapeFn(shape_inference::DatasetIteratorShape);
487 
// Registers "ExperimentalGroupByWindowDataset". It matches
// "GroupByWindowDataset" except that it has no `metadata` attr.
488 REGISTER_OP("ExperimentalGroupByWindowDataset")
489     .Input("input_dataset: variant")
490     .Input("key_func_other_arguments: Tkey_func_other_arguments")
491     .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
492     .Input(
493         "window_size_func_other_arguments: Twindow_size_func_other_arguments")
494     .Output("handle: variant")
495     .Attr("key_func: func")
496     .Attr("reduce_func: func")
497     .Attr("window_size_func: func")
498     .Attr("Tkey_func_other_arguments: list(type) >= 0")
499     .Attr("Treduce_func_other_arguments: list(type) >= 0")
500     .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
501     .Attr("output_types: list(type) >= 1")
502     .Attr("output_shapes: list(shape) >= 1")
503     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
504                                                            "output_types"))
505     .SetShapeFn(shape_inference::ScalarShape);
506 
// Registers "IgnoreErrorsDataset": takes `input_dataset` and produces a
// dataset `handle` with scalar shape. The `log_warning` bool attr defaults
// to false.
507 REGISTER_OP("IgnoreErrorsDataset")
508     .Input("input_dataset: variant")
509     .Output("handle: variant")
510     .Attr("output_types: list(type) >= 1")
511     .Attr("output_shapes: list(shape) >= 1")
512     .Attr("log_warning: bool = false")
513     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
514                                                            "output_types"))
515     .SetShapeFn(shape_inference::ScalarShape);
516 
// Registers "ExperimentalIgnoreErrorsDataset". Its signature, including the
// `log_warning` attr defaulting to false, is identical to
// "IgnoreErrorsDataset".
517 REGISTER_OP("ExperimentalIgnoreErrorsDataset")
518     .Input("input_dataset: variant")
519     .Output("handle: variant")
520     .Attr("output_types: list(type) >= 1")
521     .Attr("output_shapes: list(shape) >= 1")
522     .Attr("log_warning: bool = false")
523     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
524                                                            "output_types"))
525     .SetShapeFn(shape_inference::ScalarShape);
526 
// Registers "IteratorGetDevice": takes a `resource` input and returns a
// string `device` with scalar shape.
527 REGISTER_OP("IteratorGetDevice")
528     .Input("resource: resource")
529     .Output("device: string")
530     .SetShapeFn(shape_inference::ScalarShape);
531 
// Registers "ExperimentalIteratorGetDevice", with the same signature as
// "IteratorGetDevice".
532 REGISTER_OP("ExperimentalIteratorGetDevice")
533     .Input("resource: resource")
534     .Output("device: string")
535     .SetShapeFn(shape_inference::ScalarShape);
536 
537 REGISTER_OP("LatencyStatsDataset")
538     .Input("input_dataset: variant")
539     .Input("tag: string")
540     .Output("handle: variant")
541     .Attr("output_types: list(type) >= 1")
542     .Attr("output_shapes: list(shape) >= 1")
543     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
544                                                            "output_types"))
__anon46ca241c0c02(shape_inference::InferenceContext* c) 545     .SetShapeFn([](shape_inference::InferenceContext* c) {
546       shape_inference::ShapeHandle tag_shape;
547       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
548       return shape_inference::ScalarShape(c);
549     });
550 
551 REGISTER_OP("ExperimentalLatencyStatsDataset")
552     .Input("input_dataset: variant")
553     .Input("tag: string")
554     .Output("handle: variant")
555     .Attr("output_types: list(type) >= 1")
556     .Attr("output_shapes: list(shape) >= 1")
557     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
558                                                            "output_types"))
__anon46ca241c0d02(shape_inference::InferenceContext* c) 559     .SetShapeFn([](shape_inference::InferenceContext* c) {
560       shape_inference::ShapeHandle tag_shape;
561       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
562       return shape_inference::ScalarShape(c);
563     });
564 
// Registers "LMDBDataset": takes string `filenames` and produces a dataset
// `handle` with scalar shape.
// NOTE(review): unlike the `patterns` check in MatchingFilesDataset, the
// shape function here is plain ScalarShape, so the rank of `filenames` is
// presumably not validated.
565 REGISTER_OP("LMDBDataset")
566     .Input("filenames: string")
567     .Output("handle: variant")
568     .Attr("output_types: list(type) >= 1")
569     .Attr("output_shapes: list(shape) >= 1")
570     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
571     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
572                                                            "output_types"))
573     .SetShapeFn(shape_inference::ScalarShape);
574 
// Registers "ExperimentalLMDBDataset". Its signature and shape function are
// identical to "LMDBDataset".
575 REGISTER_OP("ExperimentalLMDBDataset")
576     .Input("filenames: string")
577     .Output("handle: variant")
578     .Attr("output_types: list(type) >= 1")
579     .Attr("output_shapes: list(shape) >= 1")
580     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
581     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
582                                                            "output_types"))
583     .SetShapeFn(shape_inference::ScalarShape);
584 
585 REGISTER_OP("MapAndBatchDataset")
586     .Input("input_dataset: variant")
587     .Input("other_arguments: Targuments")
588     .Input("batch_size: int64")
589     .Input("num_parallel_calls: int64")
590     .Input("drop_remainder: bool")
591     .Output("handle: variant")
592     .Attr("f: func")
593     .Attr("Targuments: list(type) >= 0")
594     .Attr("output_types: list(type) >= 1")
595     .Attr("output_shapes: list(shape) >= 1")
596     .Attr("preserve_cardinality: bool = false")
597     .Attr("metadata: string = ''")
598     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
599                                                            "output_types"))
__anon46ca241c0e02(shape_inference::InferenceContext* c) 600     .SetShapeFn([](shape_inference::InferenceContext* c) {
601       // Use index from the end to retrieve the Input shapes,
602       // so that to avoid guessing the length of "other_arguments".
603       // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
604       shape_inference::ShapeHandle unused;
605       TF_RETURN_IF_ERROR(
606           c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
607       TF_RETURN_IF_ERROR(
608           c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
609       TF_RETURN_IF_ERROR(
610           c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
611 
612       return shape_inference::ScalarShape(c);
613     });
614 
615 REGISTER_OP("ExperimentalMapAndBatchDataset")
616     .Input("input_dataset: variant")
617     .Input("other_arguments: Targuments")
618     .Input("batch_size: int64")
619     .Input("num_parallel_calls: int64")
620     .Input("drop_remainder: bool")
621     .Output("handle: variant")
622     .Attr("f: func")
623     .Attr("Targuments: list(type) >= 0")
624     .Attr("output_types: list(type) >= 1")
625     .Attr("output_shapes: list(shape) >= 1")
626     .Attr("preserve_cardinality: bool = false")
627     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
628                                                            "output_types"))
__anon46ca241c0f02(shape_inference::InferenceContext* c) 629     .SetShapeFn([](shape_inference::InferenceContext* c) {
630       // Use index from the end to retrieve the Input shapes,
631       // so that to avoid guessing the length of "other_arguments".
632       // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
633       shape_inference::ShapeHandle unused;
634       TF_RETURN_IF_ERROR(
635           c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
636       TF_RETURN_IF_ERROR(
637           c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
638       TF_RETURN_IF_ERROR(
639           c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
640 
641       return shape_inference::ScalarShape(c);
642     });
643 
// Registers "ExperimentalMapDataset": takes `input_dataset` plus
// `other_arguments` (typed by `Targuments`) for the function attr `f`, and
// produces a dataset `handle` with scalar shape.
// Attr defaults: `use_inter_op_parallelism` = true,
// `preserve_cardinality` = false.
644 REGISTER_OP("ExperimentalMapDataset")
645     .Input("input_dataset: variant")
646     .Input("other_arguments: Targuments")
647     .Output("handle: variant")
648     .Attr("f: func")
649     .Attr("Targuments: list(type) >= 0")
650     .Attr("output_types: list(type) >= 1")
651     .Attr("output_shapes: list(shape) >= 1")
652     .Attr("use_inter_op_parallelism: bool = true")
653     .Attr("preserve_cardinality: bool = false")
654     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
655                                                            "output_types"))
656     .SetShapeFn(shape_inference::ScalarShape);
657 
658 REGISTER_OP("MatchingFilesDataset")
659     .Input("patterns: string")
660     .Output("handle: variant")
661     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
662     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
663                                                         TFT_STRING))
__anon46ca241c1002(shape_inference::InferenceContext* c) 664     .SetShapeFn([](shape_inference::InferenceContext* c) {
665       shape_inference::ShapeHandle unused;
666       // `patterns` must be a scalar or a vector.
667       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
668       return shape_inference::ScalarShape(c);
669     });
670 
671 REGISTER_OP("ExperimentalMatchingFilesDataset")
672     .Input("patterns: string")
673     .Output("handle: variant")
674     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
675     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
676                                                         TFT_STRING))
__anon46ca241c1102(shape_inference::InferenceContext* c) 677     .SetShapeFn([](shape_inference::InferenceContext* c) {
678       shape_inference::ShapeHandle unused;
679       // `patterns` must be a scalar or a vector.
680       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
681       return shape_inference::ScalarShape(c);
682     });
683 
684 REGISTER_OP("MaxIntraOpParallelismDataset")
685     .Input("input_dataset: variant")
686     .Input("max_intra_op_parallelism: int64")
687     .Output("handle: variant")
688     .Attr("output_types: list(type) >= 1")
689     .Attr("output_shapes: list(shape) >= 1")
690     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
691                                                            "output_types"))
692     .SetShapeFn(shape_inference::ScalarShape);
693 
694 REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
695     .Input("input_dataset: variant")
696     .Input("max_intra_op_parallelism: int64")
697     .Output("handle: variant")
698     .Attr("output_types: list(type) >= 1")
699     .Attr("output_shapes: list(shape) >= 1")
700     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
701                                                            "output_types"))
702     .SetShapeFn(shape_inference::ScalarShape);
703 
704 REGISTER_OP("NonSerializableDataset")
705     .Input("input_dataset: variant")
706     .Output("handle: variant")
707     .Attr("output_types: list(type) >= 1")
708     .Attr("output_shapes: list(shape) >= 1")
709     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
710                                                            "output_types"))
711     .SetShapeFn(shape_inference::ScalarShape);
712 
713 REGISTER_OP("ExperimentalNonSerializableDataset")
714     .Input("input_dataset: variant")
715     .Output("handle: variant")
716     .Attr("output_types: list(type) >= 1")
717     .Attr("output_shapes: list(shape) >= 1")
718     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
719                                                            "output_types"))
720     .SetShapeFn(shape_inference::ScalarShape);
721 
722 REGISTER_OP("ParallelInterleaveDataset")
723     .Input("input_dataset: variant")
724     .Input("other_arguments: Targuments")
725     .Input("cycle_length: int64")
726     .Input("block_length: int64")
727     .Input("sloppy: bool")
728     .Input("buffer_output_elements: int64")
729     .Input("prefetch_input_elements: int64")
730     .Output("handle: variant")
731     .Attr("f: func")
732     .Attr("Targuments: list(type) >= 0")
733     .Attr("output_types: list(type) >= 1")
734     .Attr("output_shapes: list(shape) >= 1")
735     .Attr("metadata: string = ''")
736     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
737                                                            "output_types"))
738     .SetShapeFn(shape_inference::ScalarShape);
739 
740 // This is the V2 of ParallelInterleaveDataset, renamed to differentiate it
741 // from the non-experimental ParallelInterleaveDataset op.
742 REGISTER_OP("LegacyParallelInterleaveDatasetV2")
743     .Input("input_dataset: variant")
744     .Input("other_arguments: Targuments")
745     .Input("cycle_length: int64")
746     .Input("block_length: int64")
747     .Input("buffer_output_elements: int64")
748     .Input("prefetch_input_elements: int64")
749     .Output("handle: variant")
750     .Attr("f: func")
751     // "true", "false", or "default".
752     .Attr("deterministic: string = 'default'")
753     .Attr("Targuments: list(type) >= 0")
754     .Attr("output_types: list(type) >= 1")
755     .Attr("output_shapes: list(shape) >= 1")
756     .Attr("metadata: string = ''")
757     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
758                                                            "output_types"))
759     .SetShapeFn(shape_inference::ScalarShape);
760 
761 // This op is no longer used. We keep it so that we can read graphs written by
762 // old versions of TensorFlow.
763 REGISTER_OP("ExperimentalParallelInterleaveDataset")
764     .Input("input_dataset: variant")
765     .Input("other_arguments: Targuments")
766     .Input("cycle_length: int64")
767     .Input("block_length: int64")
768     .Input("sloppy: bool")
769     .Input("buffer_output_elements: int64")
770     .Input("prefetch_input_elements: int64")
771     .Output("handle: variant")
772     .Attr("f: func")
773     .Attr("Targuments: list(type) >= 0")
774     .Attr("output_types: list(type) >= 1")
775     .Attr("output_shapes: list(shape) >= 1")
776     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
777                                                            "output_types"))
778     .SetShapeFn(shape_inference::ScalarShape);
779 
780 REGISTER_OP("ParseExampleDataset")
781     .Input("input_dataset: variant")
782     .Input("num_parallel_calls: int64")
783     .Input("dense_defaults: Tdense")
784     .Output("handle: variant")
785     .Attr("sparse_keys: list(string) >= 0")
786     .Attr("dense_keys: list(string) >= 0")
787     .Attr("sparse_types: list({float,int64,string}) >= 0")
788     .Attr("Tdense: list({float,int64,string}) >= 0")
789     .Attr("dense_shapes: list(shape) >= 0")
790     .Attr("output_types: list(type) >= 1")
791     .Attr("output_shapes: list(shape) >= 1")  // Output components will be
792                                               // sorted by key (dense_keys and
793                                               // sparse_keys combined) here.
794     .Attr("sloppy: bool = false")
795     .Attr("ragged_keys: list(string) >= 0 = []")
796     .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
797     .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
798     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
799                                                            "output_types"))
800     .SetShapeFn(shape_inference::ScalarShape);
801 
802 REGISTER_OP("ParseExampleDatasetV2")
803     .Input("input_dataset: variant")
804     .Input("num_parallel_calls: int64")
805     .Input("dense_defaults: Tdense")
806     .Output("handle: variant")
807     .Attr("sparse_keys: list(string) >= 0")
808     .Attr("dense_keys: list(string) >= 0")
809     .Attr("sparse_types: list({float,int64,string}) >= 0")
810     .Attr("Tdense: list({float,int64,string}) >= 0")
811     .Attr("dense_shapes: list(shape) >= 0")
812     .Attr("output_types: list(type) >= 1")
813     .Attr("output_shapes: list(shape) >= 1")  // Output components will be
814                                               // sorted by key (dense_keys and
815                                               // sparse_keys combined) here.
816     // "true", "false", or "default".
817     .Attr("deterministic: string = 'default'")
818     .Attr("ragged_keys: list(string) >= 0 = []")
819     .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
820     .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
821     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
822                                                            "output_types"))
823     .SetShapeFn(shape_inference::ScalarShape);
824 
825 REGISTER_OP("ExperimentalParseExampleDataset")
826     .Input("input_dataset: variant")
827     .Input("num_parallel_calls: int64")
828     .Input("dense_defaults: Tdense")
829     .Output("handle: variant")
830     .Attr("sparse_keys: list(string) >= 0")
831     .Attr("dense_keys: list(string) >= 0")
832     .Attr("sparse_types: list({float,int64,string}) >= 0")
833     .Attr("Tdense: list({float,int64,string}) >= 0")
834     .Attr("dense_shapes: list(shape) >= 0")
835     .Attr("output_types: list(type) >= 1")
836     .Attr("output_shapes: list(shape) >= 1")  // Output components will be
837                                               // sorted by key (dense_keys and
838                                               // sparse_keys combined) here.
839     .Attr("sloppy: bool = false")
840     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
841                                                            "output_types"))
842     .SetShapeFn(shape_inference::ScalarShape);
843 
844 REGISTER_OP("PrivateThreadPoolDataset")
845     .Input("input_dataset: variant")
846     .Input("num_threads: int64")
847     .Output("handle: variant")
848     .Attr("output_types: list(type) >= 1")
849     .Attr("output_shapes: list(shape) >= 1")
850     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
851                                                            "output_types"))
852     .SetShapeFn(shape_inference::ScalarShape);
853 
854 REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
855     .Input("input_dataset: variant")
856     .Input("num_threads: int64")
857     .Output("handle: variant")
858     .Attr("output_types: list(type) >= 1")
859     .Attr("output_shapes: list(shape) >= 1")
860     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
861                                                            "output_types"))
862     .SetShapeFn(shape_inference::ScalarShape);
863 
864 REGISTER_OP("ExperimentalRandomDataset")
865     .Input("seed: int64")
866     .Input("seed2: int64")
867     .Output("handle: variant")
868     .Attr("output_types: list(type) >= 1")
869     .Attr("output_shapes: list(shape) >= 1")
870     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
871     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
872                                                            "output_types"))
__anon46ca241c1202(shape_inference::InferenceContext* c) 873     .SetShapeFn([](shape_inference::InferenceContext* c) {
874       shape_inference::ShapeHandle unused;
875       // buffer_size, seed, and seed2 should be scalars.
876       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
877       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
878       return shape_inference::ScalarShape(c);
879     });
880 
881 REGISTER_OP("RandomDataset")
882     .Input("seed: int64")
883     .Input("seed2: int64")
884     .Output("handle: variant")
885     .Attr("output_types: list(type) >= 1")
886     .Attr("output_shapes: list(shape) >= 1")
887     .Attr("metadata: string = ''")
888     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
889     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
890                                                            "output_types"))
__anon46ca241c1302(shape_inference::InferenceContext* c) 891     .SetShapeFn([](shape_inference::InferenceContext* c) {
892       shape_inference::ShapeHandle unused;
893       // buffer_size, seed, and seed2 should be scalars.
894       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
895       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
896       return shape_inference::ScalarShape(c);
897     });
898 
899 REGISTER_OP("ExperimentalRebatchDataset")
900     .Input("input_dataset: variant")
901     .Input("num_replicas: int64")
902     .Output("handle: variant")
903     .Attr("output_types: list(type) >= 1")
904     .Attr("output_shapes: list(shape) >= 1")
905     .Attr("use_fallback: bool = true")
906     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
907                                                            "output_types"))
908     .SetShapeFn(shape_inference::ScalarShape);
909 
910 REGISTER_OP("RebatchDataset")
911     .Input("input_dataset: variant")
912     .Input("num_replicas: int64")
913     .Output("handle: variant")
914     .Attr("output_types: list(type) >= 1")
915     .Attr("output_shapes: list(shape) >= 1")
916     .Attr("use_fallback: bool = true")
917     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
918                                                            "output_types"))
919     .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
920                                               full_type::BatchTensor))
921     .SetShapeFn(shape_inference::ScalarShape);
922 
923 REGISTER_OP("RebatchDatasetV2")
924     .Input("input_dataset: variant")
925     .Input("batch_sizes: int64")
926     .Input("drop_remainder: bool")
927     .Output("handle: variant")
928     .Attr("output_types: list(type) >= 1")
929     .Attr("output_shapes: list(shape) >= 1")
930     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
931                                                            "output_types"))
932     .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
933                                               full_type::BatchTensor))
934     .SetShapeFn(shape_inference::ScalarShape);
935 
936 REGISTER_OP("SamplingDataset")
937     .Input("input_dataset: variant")
938     .Input("rate: float32")
939     .Input("seed: int64")
940     .Input("seed2: int64")
941     .Output("handle: variant")
942     .Attr("output_types: list(type) >= 1")
943     .Attr("output_shapes: list(shape) >= 1")
944     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
945                                                            "output_types"))
__anon46ca241c1402(shape_inference::InferenceContext* c) 946     .SetShapeFn([](shape_inference::InferenceContext* c) {
947       shape_inference::ShapeHandle unused;
948       // rate, seed, and seed2 should be scalars.
949       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
950       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
951       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
952       return shape_inference::ScalarShape(c);
953     });
954 
955 REGISTER_OP("ScanDataset")
956     .Input("input_dataset: variant")
957     .Input("initial_state: Tstate")
958     .Input("other_arguments: Targuments")
959     .Output("handle: variant")
960     .Attr("f: func")
961     .Attr("Tstate: list(type) >= 1")
962     .Attr("Targuments: list(type) >= 0")
963     .Attr("output_types: list(type) >= 1")
964     .Attr("output_shapes: list(shape) >= 1")
965     .Attr("preserve_cardinality: bool = false")
966     .Attr("use_default_device: bool = true")
967     .Attr("metadata: string = ''")
968     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
969                                                            "output_types"))
970     .SetShapeFn(shape_inference::ScalarShape);
971 
972 REGISTER_OP("ExperimentalScanDataset")
973     .Input("input_dataset: variant")
974     .Input("initial_state: Tstate")
975     .Input("other_arguments: Targuments")
976     .Output("handle: variant")
977     .Attr("f: func")
978     .Attr("Tstate: list(type) >= 1")
979     .Attr("Targuments: list(type) >= 0")
980     .Attr("output_types: list(type) >= 1")
981     .Attr("output_shapes: list(shape) >= 1")
982     .Attr("preserve_cardinality: bool = false")
983     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
984                                                            "output_types"))
985     .SetShapeFn(shape_inference::ScalarShape);
986 
987 REGISTER_OP("SetStatsAggregatorDataset")
988     .Input("input_dataset: variant")
989     .Input("stats_aggregator: resource")
990     .Input("tag: string")
991     .Input("counter_prefix: string")
992     .Output("handle: variant")
993     .Attr("output_types: list(type) >= 1")
994     .Attr("output_shapes: list(shape) >= 1")
995     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
996                                                            "output_types"))
997     .SetShapeFn(shape_inference::ScalarShape);
998 
999 REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
1000     .Input("input_dataset: variant")
1001     .Input("stats_aggregator: resource")
1002     .Input("tag: string")
1003     .Input("counter_prefix: string")
1004     .Output("handle: variant")
1005     .Attr("output_types: list(type) >= 1")
1006     .Attr("output_shapes: list(shape) >= 1")
1007     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1008                                                            "output_types"))
1009     .SetShapeFn(shape_inference::ScalarShape);
1010 
1011 REGISTER_OP("SleepDataset")
1012     .Input("input_dataset: variant")
1013     .Input("sleep_microseconds: int64")
1014     .Output("handle: variant")
1015     .Attr("output_types: list(type) >= 1")
1016     .Attr("output_shapes: list(shape) >= 1")
1017     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1018                                                            "output_types"))
__anon46ca241c1502(shape_inference::InferenceContext* c) 1019     .SetShapeFn([](shape_inference::InferenceContext* c) {
1020       shape_inference::ShapeHandle unused;
1021       // Both inputs are scalar.
1022       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
1023       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
1024       return shape_inference::ScalarShape(c);
1025     });
1026 
1027 REGISTER_OP("ExperimentalSleepDataset")
1028     .Input("input_dataset: variant")
1029     .Input("sleep_microseconds: int64")
1030     .Output("handle: variant")
1031     .Attr("output_types: list(type) >= 1")
1032     .Attr("output_shapes: list(shape) >= 1")
1033     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1034                                                            "output_types"))
__anon46ca241c1602(shape_inference::InferenceContext* c) 1035     .SetShapeFn([](shape_inference::InferenceContext* c) {
1036       shape_inference::ShapeHandle unused;
1037       // Both inputs are scalar.
1038       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
1039       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
1040       return shape_inference::ScalarShape(c);
1041     });
1042 
1043 REGISTER_OP("SlidingWindowDataset")
1044     .Input("input_dataset: variant")
1045     .Input("window_size: int64")
1046     .Input("window_shift: int64")
1047     .Input("window_stride: int64")
1048     .Output("handle: variant")
1049     .Attr("drop_remainder: bool = true")
1050     .Attr("output_types: list(type) >= 1")
1051     .Attr("output_shapes: list(shape) >= 1")
1052     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1053                                                            "output_types"))
__anon46ca241c1702(shape_inference::InferenceContext* c) 1054     .SetShapeFn([](shape_inference::InferenceContext* c) {
1055       shape_inference::ShapeHandle unused;
1056       // window_size, window_shift, and window_stride should be scalars.
1057       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1058       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1059       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1060       return shape_inference::ScalarShape(c);
1061     });
1062 
1063 REGISTER_OP("ExperimentalSlidingWindowDataset")
1064     .Input("input_dataset: variant")
1065     .Input("window_size: int64")
1066     .Input("window_shift: int64")
1067     .Input("window_stride: int64")
1068     .Output("handle: variant")
1069     .Attr("output_types: list(type) >= 1")
1070     .Attr("output_shapes: list(shape) >= 1")
1071     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1072                                                            "output_types"))
__anon46ca241c1802(shape_inference::InferenceContext* c) 1073     .SetShapeFn([](shape_inference::InferenceContext* c) {
1074       shape_inference::ShapeHandle unused;
1075       // window_size, window_shift, and window_stride should be scalars.
1076       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1077       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1078       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1079       return shape_inference::ScalarShape(c);
1080     });
1081 
1082 REGISTER_OP("SnapshotDataset")
1083     .Input("input_dataset: variant")
1084     .Input("path: string")
1085     .Output("handle: variant")
1086     .Attr("output_types: list(type) >= 1")
1087     .Attr("output_shapes: list(shape) >= 1")
1088     .Attr("compression: string = ''")
1089     .Attr("reader_path_prefix: string = ''")
1090     .Attr("writer_path_prefix: string = ''")
1091     .Attr("shard_size_bytes: int = 10737418240")           // 10 GiB default
1092     .Attr("pending_snapshot_expiry_seconds: int = 86400")  // 1 day default
1093     .Attr("num_reader_threads: int = 1")
1094     .Attr("reader_buffer_size: int = 1")
1095     .Attr("num_writer_threads: int = 1")
1096     .Attr("writer_buffer_size: int = 1")
1097     .Attr("shuffle_on_read: bool = false")
1098     .Attr("seed: int = 0")
1099     .Attr("seed2: int = 0")
1100     .Attr("mode: string = 'auto'")
1101     .Attr("snapshot_name: string = ''")
1102     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1103                                                            "output_types"))
__anon46ca241c1902(shape_inference::InferenceContext* c) 1104     .SetShapeFn([](shape_inference::InferenceContext* c) {
1105       shape_inference::ShapeHandle unused;
1106       // snapshot_path should be a scalar.
1107       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1108       return shape_inference::ScalarShape(c);
1109     });
1110 
1111 REGISTER_OP("SnapshotDatasetV2")
1112     .Input("input_dataset: variant")
1113     .Input("path: string")
1114     .Input("reader_func_other_args: Treader_func_args")
1115     .Input("shard_func_other_args: Tshard_func_args")
1116     .Output("handle: variant")
1117     .Attr("output_types: list(type) >= 1")
1118     .Attr("output_shapes: list(shape) >= 1")
1119     .Attr("compression: string = ''")
1120     .Attr("reader_prefix: string = ''")
1121     .Attr("writer_prefix: string = ''")
1122     .Attr("hash_valid: bool = false")
1123     .Attr("hash: int = 0")
1124     .Attr("reader_func: func")
1125     .Attr("shard_func: func")
1126     .Attr("Treader_func_args: list(type) >= 0")
1127     .Attr("Tshard_func_args: list(type) >= 0")
1128     .Attr("metadata: string = ''")
1129     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1130                                                            "output_types"))
__anon46ca241c1a02(shape_inference::InferenceContext* c) 1131     .SetShapeFn([](shape_inference::InferenceContext* c) {
1132       shape_inference::ShapeHandle unused;
1133       // `path` should be a scalar.
1134       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1135       return shape_inference::ScalarShape(c);
1136     });
1137 
1138 REGISTER_OP("SaveDataset")
1139     .Input("input_dataset: variant")
1140     .Input("path: string")
1141     .Input("shard_func_other_args: Tshard_func_args")
1142     .Attr("compression: string = ''")
1143     .Attr("shard_func: func")
1144     .Attr("use_shard_func: bool = true")
1145     .Attr("Tshard_func_args: list(type) >= 0")
1146     .SetIsStateful()
__anon46ca241c1b02(shape_inference::InferenceContext* c) 1147     .SetShapeFn([](shape_inference::InferenceContext* c) {
1148       shape_inference::ShapeHandle unused;
1149       // `path` should be a scalar.
1150       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1151       return OkStatus();
1152     });
1153 
1154 REGISTER_OP("SaveDatasetV2")
1155     .Input("input_dataset: variant")
1156     .Input("path: string")
1157     .Input("shard_func_other_args: Tshard_func_args")
1158     .Output("handle: variant")
1159     .Attr("compression: string = ''")
1160     .Attr("shard_func: func")
1161     .Attr("use_shard_func: bool = true")
1162     .Attr("Tshard_func_args: list(type) >= 0")
1163     .Attr("output_types: list(type) >= 1")
1164     .Attr("output_shapes: list(shape) >= 1")
1165     .SetIsStateful()
__anon46ca241c1c02(shape_inference::InferenceContext* c) 1166     .SetShapeFn([](shape_inference::InferenceContext* c) {
1167       shape_inference::ShapeHandle unused;
1168       // `path` should be a scalar.
1169       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1170       return shape_inference::ScalarShape(c);
1171     });
1172 
1173 REGISTER_OP("LoadDataset")
1174     .Input("path: string")
1175     .Input("reader_func_other_args: Treader_func_args")
1176     .Output("handle: variant")
1177     .Attr("output_types: list(type) >= 1")
1178     .Attr("output_shapes: list(shape) >= 1")
1179     .Attr("compression: string = ''")
1180     .Attr("reader_func: func")
1181     .Attr("Treader_func_args: list(type) >= 0")
1182     .SetIsStateful()
1183     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1184                                                            "output_types"))
__anon46ca241c1d02(shape_inference::InferenceContext* c) 1185     .SetShapeFn([](shape_inference::InferenceContext* c) {
1186       shape_inference::ShapeHandle unused;
1187       // `path` should be a scalar.
1188       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1189       return shape_inference::ScalarShape(c);
1190     });
1191 
1192 REGISTER_OP("SnapshotDatasetReader")
1193     .Input("shard_dir: string")
1194     .Input("start_index: int64")
1195     .Output("handle: variant")
1196     .Attr("output_types: list(type) >= 1")
1197     .Attr("output_shapes: list(shape) >= 1")
1198     .Attr("compression: string = ''")
1199     .Attr("version: int")
1200     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1201                                                            "output_types"))
__anon46ca241c1e02(shape_inference::InferenceContext* c) 1202     .SetShapeFn([](shape_inference::InferenceContext* c) {
1203       shape_inference::ShapeHandle unused;
1204       // `shard_dir` should be a scalar.
1205       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1206       // `start_index` should be a scalar.
1207       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1208       return shape_inference::ScalarShape(c);
1209     });
1210 
1211 REGISTER_OP("SnapshotNestedDatasetReader")
1212     .Input("inputs: N * variant")
1213     .Output("handle: variant")
1214     .Attr("output_types: list(type) >= 1")
1215     .Attr("output_shapes: list(shape) >= 1")
1216     .Attr("N: int >= 1")
1217     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1218                                                            "output_types"))
1219     .SetShapeFn(shape_inference::ScalarShape);
1220 
1221 REGISTER_OP("SqlDataset")
1222     .Input("driver_name: string")
1223     .Input("data_source_name: string")
1224     .Input("query: string")
1225     .Output("handle: variant")
1226     .Attr("output_types: list(type) >= 1")
1227     .Attr("output_shapes: list(shape) >= 1")
1228     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
1229     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1230                                                            "output_types"))
__anon46ca241c1f02(shape_inference::InferenceContext* c) 1231     .SetShapeFn([](shape_inference::InferenceContext* c) {
1232       shape_inference::ShapeHandle unused;
1233       // driver_name, data_source_name, and query should be scalars.
1234       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1235       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1236       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1237       return shape_inference::ScalarShape(c);
1238     });
1239 
1240 REGISTER_OP("ExperimentalSqlDataset")
1241     .Input("driver_name: string")
1242     .Input("data_source_name: string")
1243     .Input("query: string")
1244     .Output("handle: variant")
1245     .Attr("output_types: list(type) >= 1")
1246     .Attr("output_shapes: list(shape) >= 1")
1247     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
1248     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1249                                                            "output_types"))
__anon46ca241c2002(shape_inference::InferenceContext* c) 1250     .SetShapeFn([](shape_inference::InferenceContext* c) {
1251       shape_inference::ShapeHandle unused;
1252       // driver_name, data_source_name, and query should be scalars.
1253       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1254       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1255       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1256       return shape_inference::ScalarShape(c);
1257     });
1258 
1259 REGISTER_OP("StatsAggregatorHandle")
1260     .Output("handle: resource")
1261     .SetShapeFn(shape_inference::ScalarShape)
1262     .Attr("container: string = ''")
1263     .Attr("shared_name: string = ''");
1264 
1265 REGISTER_OP("ExperimentalStatsAggregatorHandle")
1266     .Output("handle: resource")
1267     .SetShapeFn(shape_inference::ScalarShape)
1268     .Attr("container: string = ''")
1269     .Attr("shared_name: string = ''");
1270 
1271 REGISTER_OP("StatsAggregatorHandleV2")
1272     .Output("handle: resource")
1273     .SetShapeFn(shape_inference::ScalarShape)
1274     .Attr("container: string = ''")
1275     .Attr("shared_name: string = ''");
1276 
1277 REGISTER_OP("StatsAggregatorSetSummaryWriter")
1278     .Input("stats_aggregator: resource")
1279     .Input("summary: resource")
1280     .SetShapeFn(shape_inference::NoOutputs);
1281 
1282 REGISTER_OP("StatsAggregatorSummary")
1283     .Input("iterator: resource")
1284     .Output("summary: string")
1285     .SetShapeFn(shape_inference::ScalarShape);
1286 
1287 REGISTER_OP("ExperimentalStatsAggregatorSummary")
1288     .Input("iterator: resource")
1289     .Output("summary: string")
1290     .SetShapeFn(shape_inference::ScalarShape);
1291 
1292 REGISTER_OP("TakeWhileDataset")
1293     .Input("input_dataset: variant")
1294     .Input("other_arguments: Targuments")
1295     .Output("handle: variant")
1296     .Attr("predicate: func")
1297     .Attr("Targuments: list(type) >= 0")
1298     .Attr("output_types: list(type) >= 1")
1299     .Attr("output_shapes: list(shape) >= 1")
1300     .Attr("metadata: string = ''")
1301     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1302                                                            "output_types"))
1303     .SetShapeFn(shape_inference::ScalarShape);
1304 
1305 REGISTER_OP("ExperimentalTakeWhileDataset")
1306     .Input("input_dataset: variant")
1307     .Input("other_arguments: Targuments")
1308     .Output("handle: variant")
1309     .Attr("predicate: func")
1310     .Attr("Targuments: list(type) >= 0")
1311     .Attr("output_types: list(type) >= 1")
1312     .Attr("output_shapes: list(shape) >= 1")
1313     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1314                                                            "output_types"))
1315     .SetShapeFn(shape_inference::ScalarShape);
1316 
1317 REGISTER_OP("ThreadPoolDataset")
1318     .Input("input_dataset: variant")
1319     .Input("thread_pool: resource")
1320     .Output("handle: variant")
1321     .Attr("output_types: list(type) >= 1")
1322     .Attr("output_shapes: list(shape) >= 1")
1323     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1324                                                            "output_types"))
1325     .SetShapeFn(shape_inference::ScalarShape);
1326 
1327 REGISTER_OP("ExperimentalThreadPoolDataset")
1328     .Input("input_dataset: variant")
1329     .Input("thread_pool: resource")
1330     .Output("handle: variant")
1331     .Attr("output_types: list(type) >= 1")
1332     .Attr("output_shapes: list(shape) >= 1")
1333     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1334                                                            "output_types"))
1335     .SetShapeFn(shape_inference::ScalarShape);
1336 
1337 REGISTER_OP("ThreadPoolHandle")
1338     .Output("handle: resource")
1339     .SetShapeFn(shape_inference::ScalarShape)
1340     .Attr("num_threads: int")
1341     .Attr("max_intra_op_parallelism: int = 1")
1342     .Attr("display_name: string")
1343     .Attr("container: string = ''")
1344     .Attr("shared_name: string = ''");
1345 
1346 REGISTER_OP("ExperimentalThreadPoolHandle")
1347     .Output("handle: resource")
1348     .SetShapeFn(shape_inference::ScalarShape)
1349     .Attr("num_threads: int")
1350     .Attr("max_intra_op_parallelism: int = 1")
1351     .Attr("display_name: string")
1352     .Attr("container: string = ''")
1353     .Attr("shared_name: string = ''");
1354 
1355 REGISTER_OP("UnbatchDataset")
1356     .Input("input_dataset: variant")
1357     .Output("handle: variant")
1358     .Attr("output_types: list(type) >= 1")
1359     .Attr("output_shapes: list(shape) >= 1")
1360     .Attr("metadata: string = ''")
1361     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1362                                                            "output_types"))
1363     .SetShapeFn(shape_inference::ScalarShape);
1364 
1365 REGISTER_OP("ExperimentalUnbatchDataset")
1366     .Input("input_dataset: variant")
1367     .Output("handle: variant")
1368     .Attr("output_types: list(type) >= 1")
1369     .Attr("output_shapes: list(shape) >= 1")
1370     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1371                                                            "output_types"))
1372     .SetShapeFn(shape_inference::ScalarShape);
1373 
1374 REGISTER_OP("UniqueDataset")
1375     .Input("input_dataset: variant")
1376     .Output("handle: variant")
1377     .Attr("output_types: list(type) >= 1")
1378     .Attr("output_shapes: list(shape) >= 1")
1379     .Attr("metadata: string = ''")
1380     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1381                                                            "output_types"))
1382     .SetShapeFn(shape_inference::ScalarShape);
1383 
1384 REGISTER_OP("ExperimentalUniqueDataset")
1385     .Input("input_dataset: variant")
1386     .Output("handle: variant")
1387     .Attr("output_types: list(type) >= 1")
1388     .Attr("output_shapes: list(shape) >= 1")
1389     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1390                                                            "output_types"))
1391     .SetShapeFn(shape_inference::ScalarShape);
1392 
1393 REGISTER_OP("DummyIterationCounter")
1394     .Output("handle: resource")
__anon46ca241c2102(shape_inference::InferenceContext* c) 1395     .SetShapeFn([](shape_inference::InferenceContext* c) {
1396       c->set_output(0, c->Scalar());
1397       return OkStatus();
1398     });
1399 
1400 REGISTER_OP("DataServiceDataset")
1401     .Input("dataset_id: int64")
1402     .Input("processing_mode: string")
1403     .Input("address: string")
1404     .Input("protocol: string")
1405     .Input("job_name: string")
1406     .Input("max_outstanding_requests: int64")
1407     .Input("iteration_counter: resource")
1408     .Output("handle: variant")
1409     .Attr("task_refresh_interval_hint_ms: int = -1")
1410     .Attr("output_types: list(type) >= 1")
1411     .Attr("output_shapes: list(shape) >= 1")
1412     .Attr("data_transfer_protocol: string = ''")
1413     .Attr("target_workers: string = 'AUTO'")
1414     .Attr("cross_trainer_cache_options: string = ''")
1415     .SetIsStateful()
1416     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1417                                                            "output_types"))
1418     .SetShapeFn(shape_inference::ScalarShape);
1419 
1420 // Adds `consumer_index` and `num_consumers` arguments to support round-robin
1421 // reads.
1422 REGISTER_OP("DataServiceDatasetV2")
1423     .Input("dataset_id: int64")
1424     .Input("processing_mode: string")
1425     .Input("address: string")
1426     .Input("protocol: string")
1427     .Input("job_name: string")
1428     .Input("consumer_index: int64")
1429     .Input("num_consumers: int64")
1430     .Input("max_outstanding_requests: int64")
1431     .Input("iteration_counter: resource")
1432     .Output("handle: variant")
1433     .Attr("task_refresh_interval_hint_ms: int = -1")
1434     .Attr("output_types: list(type) >= 1")
1435     .Attr("output_shapes: list(shape) >= 1")
1436     .Attr("data_transfer_protocol: string = ''")
1437     .Attr("target_workers: string = 'AUTO'")
1438     .Attr("cross_trainer_cache_options: string = ''")
1439     .SetIsStateful()
1440     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1441                                                            "output_types"))
1442     .SetShapeFn(shape_inference::ScalarShape);
1443 
1444 // Adds `uncompress` and `uncompress_fn` attributes to support uncompression.
1445 REGISTER_OP("DataServiceDatasetV3")
1446     .Input("dataset_id: int64")
1447     .Input("processing_mode: string")
1448     .Input("address: string")
1449     .Input("protocol: string")
1450     .Input("job_name: string")
1451     .Input("consumer_index: int64")
1452     .Input("num_consumers: int64")
1453     .Input("max_outstanding_requests: int64")
1454     .Input("iteration_counter: resource")
1455     .Output("handle: variant")
1456     .Attr("task_refresh_interval_hint_ms: int = -1")
1457     .Attr("output_types: list(type) >= 1")
1458     .Attr("output_shapes: list(shape) >= 1")
1459     .Attr("data_transfer_protocol: string = ''")
1460     .Attr("target_workers: string = 'AUTO'")
1461     .Attr("uncompress: bool = false")
1462     .Attr("uncompress_fn: func")
1463     .Attr("cross_trainer_cache_options: string = ''")
1464     .SetIsStateful()
1465     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1466                                                            "output_types"))
1467     .SetShapeFn(shape_inference::ScalarShape);
1468 
1469 // Changes `dataset_id` from int64 to string.
1470 REGISTER_OP("DataServiceDatasetV4")
1471     .Input("dataset_id: string")
1472     .Input("processing_mode: string")
1473     .Input("address: string")
1474     .Input("protocol: string")
1475     .Input("job_name: string")
1476     .Input("consumer_index: int64")
1477     .Input("num_consumers: int64")
1478     .Input("max_outstanding_requests: int64")
1479     .Input("iteration_counter: resource")
1480     .Output("handle: variant")
1481     .Attr("task_refresh_interval_hint_ms: int = -1")
1482     .Attr("output_types: list(type) >= 1")
1483     .Attr("output_shapes: list(shape) >= 1")
1484     .Attr("data_transfer_protocol: string = ''")
1485     .Attr("target_workers: string = 'AUTO'")
1486     .Attr("uncompress: bool = false")
1487     .Attr("uncompress_fn: func")
1488     .Attr("cross_trainer_cache_options: string = ''")
1489     .SetIsStateful()
1490     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1491                                                            "output_types"))
1492     .SetShapeFn(shape_inference::ScalarShape);
1493 
1494 REGISTER_OP("RegisterDataset")
1495     .Input("dataset: variant")
1496     .Input("address: string")
1497     .Input("protocol: string")
1498     .Output("dataset_id: int64")
1499     .Attr("external_state_policy: int")
1500     .Attr("element_spec: string = ''")
1501     .Attr("metadata: string = ''")
1502     .SetShapeFn(shape_inference::ScalarShape);
1503 
1504 // Changes `dataset_id` from int64 to string.
1505 REGISTER_OP("RegisterDatasetV2")
1506     .Input("dataset: variant")
1507     .Input("address: string")
1508     .Input("protocol: string")
1509     .Output("dataset_id: string")
1510     .Attr("external_state_policy: int")
1511     .Attr("element_spec: string = ''")
1512     .Attr("requested_dataset_id: string = ''")
1513     .Attr("metadata: string = ''")
1514     .SetShapeFn(shape_inference::ScalarShape);
1515 
1516 REGISTER_OP("InitializeTableFromDataset")
1517     .Input("table_handle: resource")
1518     .Input("dataset: variant")
__anon46ca241c2202(shape_inference::InferenceContext* c) 1519     .SetShapeFn([](shape_inference::InferenceContext* c) {
1520       shape_inference::ShapeHandle handle;
1521       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
1522       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
1523       return OkStatus();
1524     });
1525 
1526 // - `output_types` is the types of tensors in a single dataset element.
1527 // - `output_shapes` is the shapes of tensors in a single dataset element.
1528 // - `output_types` and `output_shapes` are the same size: the number of
1529 // tensors in a single dataset element, a.k.a. the number of components.
1530 // - `Tinput_types` is the types of tensors for all dataset elements.
1531 // `Tinput_types` is equivalent to `output_types` repeated for N total dataset
1532 // elements.
1533 REGISTER_OP("ListDataset")
1534     .Input("tensors: Tinput_types")
1535     .Output("handle: variant")
1536     .Attr("Tinput_types: list(type) >= 1")
1537     .Attr("output_types: list(type) >= 1")
1538     .Attr("output_shapes: list(shape) >= 1")
1539     .Attr("metadata: string = ''")
1540     .SetDoNotOptimize()
1541     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1542                                                            "output_types"))
1543     .SetShapeFn(shape_inference::ScalarShape);
1544 
1545 }  // namespace tensorflow
1546