1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/full_type.pb.h" 17 #include "tensorflow/core/framework/op.h" 18 19 namespace tensorflow { 20 21 REGISTER_OP("AssertCardinalityDataset") 22 .Input("input_dataset: variant") 23 .Input("cardinality: int64") 24 .Output("handle: variant") 25 .Attr("output_types: list(type) >= 1") 26 .Attr("output_shapes: list(shape) >= 1") 27 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 28 "output_types")) __anon46ca241c0102(shape_inference::InferenceContext* c) 29 .SetShapeFn([](shape_inference::InferenceContext* c) { 30 shape_inference::ShapeHandle unused; 31 // cardinality should be a scalar. 
32 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 33 return shape_inference::ScalarShape(c); 34 }); 35 36 REGISTER_OP("AssertNextDataset") 37 .Input("input_dataset: variant") 38 .Input("transformations: string") 39 .Output("handle: variant") 40 .Attr("output_types: list(type) >= 1") 41 .Attr("output_shapes: list(shape) >= 1") 42 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 43 "output_types")) __anon46ca241c0202(shape_inference::InferenceContext* c) 44 .SetShapeFn([](shape_inference::InferenceContext* c) { 45 shape_inference::ShapeHandle unused; 46 // transformations should be a vector. 47 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 48 return shape_inference::ScalarShape(c); 49 }); 50 51 REGISTER_OP("ExperimentalAssertNextDataset") 52 .Input("input_dataset: variant") 53 .Input("transformations: string") 54 .Output("handle: variant") 55 .Attr("output_types: list(type) >= 1") 56 .Attr("output_shapes: list(shape) >= 1") 57 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 58 "output_types")) __anon46ca241c0302(shape_inference::InferenceContext* c) 59 .SetShapeFn([](shape_inference::InferenceContext* c) { 60 shape_inference::ShapeHandle unused; 61 // transformations should be a vector. 62 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 63 return shape_inference::ScalarShape(c); 64 }); 65 66 REGISTER_OP("AssertPrevDataset") 67 .Input("input_dataset: variant") 68 .Input("transformations: string") 69 .Output("handle: variant") 70 .Attr("output_types: list(type) >= 1") 71 .Attr("output_shapes: list(shape) >= 1") 72 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 73 "output_types")) __anon46ca241c0402(shape_inference::InferenceContext* c) 74 .SetShapeFn([](shape_inference::InferenceContext* c) { 75 shape_inference::ShapeHandle unused; 76 // transformations should be a vector. 
77 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 78 return shape_inference::ScalarShape(c); 79 }); 80 81 REGISTER_OP("AutoShardDataset") 82 .Input("input_dataset: variant") 83 .Input("num_workers: int64") 84 .Input("index: int64") 85 .Output("handle: variant") 86 .Attr("auto_shard_policy: int = 0") 87 .Attr("output_types: list(type) >= 1") 88 .Attr("output_shapes: list(shape) >= 1") 89 .Attr("num_replicas: int = 0") 90 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 91 "output_types")) 92 .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0, 93 full_type::ShardTensor)) 94 .SetShapeFn(shape_inference::ScalarShape); 95 96 REGISTER_OP("ExperimentalAutoShardDataset") 97 .Input("input_dataset: variant") 98 .Input("num_workers: int64") 99 .Input("index: int64") 100 .Output("handle: variant") 101 .Attr("auto_shard_policy: int = 0") 102 .Attr("output_types: list(type) >= 1") 103 .Attr("output_shapes: list(shape) >= 1") 104 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 105 "output_types")) 106 .SetShapeFn(shape_inference::ScalarShape); 107 108 REGISTER_OP("BytesProducedStatsDataset") 109 .Input("input_dataset: variant") 110 .Input("tag: string") 111 .Output("handle: variant") 112 .Attr("output_types: list(type) >= 1") 113 .Attr("output_shapes: list(shape) >= 1") 114 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 115 "output_types")) __anon46ca241c0502(shape_inference::InferenceContext* c) 116 .SetShapeFn([](shape_inference::InferenceContext* c) { 117 shape_inference::ShapeHandle tag_shape; 118 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); 119 return shape_inference::ScalarShape(c); 120 }); 121 122 REGISTER_OP("ExperimentalBytesProducedStatsDataset") 123 .Input("input_dataset: variant") 124 .Input("tag: string") 125 .Output("handle: variant") 126 .Attr("output_types: list(type) >= 1") 127 .Attr("output_shapes: list(shape) >= 1") 128 
.SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 129 "output_types")) __anon46ca241c0602(shape_inference::InferenceContext* c) 130 .SetShapeFn([](shape_inference::InferenceContext* c) { 131 shape_inference::ShapeHandle tag_shape; 132 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); 133 return shape_inference::ScalarShape(c); 134 }); 135 136 REGISTER_OP("ChooseFastestBranchDataset") 137 .Input("input_dataset: variant") 138 .Input("ratio_numerator: int64") 139 .Input("ratio_denominator: int64") 140 .Input("other_arguments: Targuments") 141 .Output("handle: variant") 142 .Attr("Targuments: list(type) >= 0") 143 .Attr("num_elements_per_branch: int >= 1") 144 .Attr("branches: list(func) >= 1") 145 .Attr("other_arguments_lengths: list(int) >= 1") 146 .Attr("output_types: list(type) >= 1") 147 .Attr("output_shapes: list(shape) >= 1") 148 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 149 "output_types")) 150 .SetShapeFn(shape_inference::ScalarShape); 151 152 REGISTER_OP("ChooseFastestDataset") 153 .Input("input_datasets: N * variant") 154 .Output("handle: variant") 155 .Attr("N: int >= 2") 156 .Attr("num_experiments: int") 157 .Attr("output_types: list(type) >= 1") 158 .Attr("output_shapes: list(shape) >= 1") 159 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 160 "output_types")) 161 .SetShapeFn(shape_inference::ScalarShape); 162 163 REGISTER_OP("ExperimentalChooseFastestDataset") 164 .Input("input_datasets: N * variant") 165 .Output("handle: variant") 166 .Attr("N: int >= 2") 167 .Attr("num_experiments: int") 168 .Attr("output_types: list(type) >= 1") 169 .Attr("output_shapes: list(shape) >= 1") 170 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 171 "output_types")) 172 .SetShapeFn(shape_inference::ScalarShape); 173 174 REGISTER_OP("CompressElement") 175 .Input("components: input_types") 176 .Output("compressed: variant") 177 .Attr("input_types: list(type) >= 1") 178 
.SetShapeFn(shape_inference::ScalarShape); 179 180 REGISTER_OP("UncompressElement") 181 .Input("compressed: variant") 182 .Output("components: output_types") 183 .Attr("output_types: list(type) >= 1") 184 .Attr("output_shapes: list(shape) >= 1") 185 .SetShapeFn(shape_inference::DatasetIteratorShape); 186 187 REGISTER_OP("ComputeBatchSize") 188 .Input("input_dataset : variant") 189 .Output("batch_size : int64") 190 .SetShapeFn(shape_inference::ScalarShape); 191 192 REGISTER_OP("CSVDataset") 193 .Input("filenames: string") 194 .Input("compression_type: string") 195 .Input("buffer_size: int64") 196 .Input("header: bool") 197 .Input("field_delim: string") 198 .Input("use_quote_delim: bool") 199 .Input("na_value: string") 200 .Input("select_cols: int64") 201 .Input("record_defaults: output_types") 202 .Output("handle: variant") 203 .Attr("output_types: list({float,double,int32,int64,string}) >= 1") 204 .Attr("output_shapes: list(shape) >= 1") 205 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 206 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 207 "output_types")) __anon46ca241c0702(shape_inference::InferenceContext* c) 208 .SetShapeFn([](shape_inference::InferenceContext* c) { 209 shape_inference::ShapeHandle unused; 210 // `filenames` must be a scalar or a vector. 
211 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 212 // `compression_type`, `buffer_size`, `header`, `field_delim`, 213 // `use_quote_delim`, `na_value` must be scalars 214 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 215 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 216 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 217 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 218 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 219 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); 220 // `select_cols` must be a vector 221 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); 222 // `record_defaults` must be lists of scalars 223 for (size_t i = 8; i < c->num_inputs(); ++i) { 224 shape_inference::ShapeHandle v; 225 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); 226 if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { 227 return errors::InvalidArgument( 228 "Shape of a default must be a length-0 or length-1 vector, or a " 229 "scalar."); 230 } 231 } 232 return shape_inference::ScalarShape(c); 233 }); 234 235 REGISTER_OP("CSVDatasetV2") 236 .Input("filenames: string") 237 .Input("compression_type: string") 238 .Input("buffer_size: int64") 239 .Input("header: bool") 240 .Input("field_delim: string") 241 .Input("use_quote_delim: bool") 242 .Input("na_value: string") 243 .Input("select_cols: int64") 244 .Input("record_defaults: output_types") 245 .Input("exclude_cols: int64") 246 .Output("handle: variant") 247 .Attr("output_types: list({float,double,int32,int64,string}) >= 1") 248 .Attr("output_shapes: list(shape) >= 1") 249 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 
250 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 251 "output_types")) __anon46ca241c0802(shape_inference::InferenceContext* c) 252 .SetShapeFn([](shape_inference::InferenceContext* c) { 253 shape_inference::ShapeHandle unused; 254 // `filenames` must be a scalar or a vector. 255 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 256 // `compression_type`, `buffer_size`, `header`, `field_delim`, 257 // `use_quote_delim`, `na_value` must be scalars 258 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 259 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 260 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 261 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 262 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 263 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); 264 // `select_cols` must be a vector 265 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); 266 // `exclude_cols` must be a vector 267 TF_RETURN_IF_ERROR( 268 c->WithRank(c->input(c->num_inputs() - 1), 1, &unused)); 269 // `record_defaults` must be lists of scalars 270 for (size_t i = 8; i < c->num_inputs() - 1; ++i) { 271 shape_inference::ShapeHandle v; 272 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); 273 if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { 274 return errors::InvalidArgument( 275 "Shape of a default must be a length-0 or length-1 vector, or a " 276 "scalar."); 277 } 278 } 279 return shape_inference::ScalarShape(c); 280 }); 281 282 REGISTER_OP("ExperimentalCSVDataset") 283 .Input("filenames: string") 284 .Input("compression_type: string") 285 .Input("buffer_size: int64") 286 .Input("header: bool") 287 .Input("field_delim: string") 288 .Input("use_quote_delim: bool") 289 .Input("na_value: string") 290 .Input("select_cols: int64") 291 .Input("record_defaults: output_types") 292 .Output("handle: variant") 293 .Attr("output_types: list({float,double,int32,int64,string}) >= 1") 
294 .Attr("output_shapes: list(shape) >= 1") 295 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 296 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 297 "output_types")) __anon46ca241c0902(shape_inference::InferenceContext* c) 298 .SetShapeFn([](shape_inference::InferenceContext* c) { 299 shape_inference::ShapeHandle unused; 300 // `filenames` must be a scalar or a vector. 301 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 302 // `compression_type`, `buffer_size`, `header`, `field_delim`, 303 // `use_quote_delim`, `na_value` must be scalars 304 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 305 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 306 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 307 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 308 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 309 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); 310 // `select_cols` must be a vector 311 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); 312 // `record_defaults` must be lists of scalars 313 for (size_t i = 8; i < c->num_inputs(); ++i) { 314 shape_inference::ShapeHandle v; 315 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); 316 if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { 317 return errors::InvalidArgument( 318 "Shape of a default must be a length-0 or length-1 vector, or a " 319 "scalar."); 320 } 321 } 322 return shape_inference::ScalarShape(c); 323 }); 324 325 REGISTER_OP("ExperimentalDatasetCardinality") 326 .Input("input_dataset: variant") 327 .Output("cardinality: int64") 328 .SetShapeFn(shape_inference::ScalarShape); 329 330 REGISTER_OP("DatasetFromGraph") 331 .Input("graph_def: string") 332 .Output("handle: variant") 333 .SetTypeConstructor(full_type::UnaryGeneric(TFT_DATASET)) 334 .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0)) 335 .SetShapeFn(shape_inference::ScalarShape); 336 337 // 
TODO(b/124308596): Instead of conservatively marking this op as stateful, 338 // implement a mechanism to determine whether `dataset` has a side-effect 339 // and use it to decide whether to use a stateless or stateful version of this 340 // op. 341 REGISTER_OP("DatasetToTFRecord") 342 .Input("input_dataset: variant") 343 .Input("filename: string") 344 .Input("compression_type: string") 345 .SetIsStateful() 346 .SetShapeFn(shape_inference::NoOutputs); 347 348 REGISTER_OP("ExperimentalDatasetToTFRecord") 349 .Input("input_dataset: variant") 350 .Input("filename: string") 351 .Input("compression_type: string") 352 .SetIsStateful() 353 .SetShapeFn(shape_inference::NoOutputs); 354 355 REGISTER_OP("DenseToSparseBatchDataset") 356 .Input("input_dataset: variant") 357 .Input("batch_size: int64") 358 .Input("row_shape: int64") 359 .Output("handle: variant") 360 .Attr("output_types: list(type) >= 1") 361 .Attr("output_shapes: list(shape) >= 1") 362 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 363 "output_types")) __anon46ca241c0a02(shape_inference::InferenceContext* c) 364 .SetShapeFn([](shape_inference::InferenceContext* c) { 365 shape_inference::ShapeHandle unused; 366 // batch_size should be a scalar. 367 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 368 // row_shape should be a 1-D vector. 
369 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); 370 return shape_inference::ScalarShape(c); 371 }); 372 373 REGISTER_OP("ExperimentalDenseToSparseBatchDataset") 374 .Input("input_dataset: variant") 375 .Input("batch_size: int64") 376 .Input("row_shape: int64") 377 .Output("handle: variant") 378 .Attr("output_types: list(type) >= 1") 379 .Attr("output_shapes: list(shape) >= 1") 380 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 381 "output_types")) __anon46ca241c0b02(shape_inference::InferenceContext* c) 382 .SetShapeFn([](shape_inference::InferenceContext* c) { 383 shape_inference::ShapeHandle unused; 384 // batch_size should be a scalar. 385 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 386 // row_shape should be a 1-D vector. 387 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); 388 return shape_inference::ScalarShape(c); 389 }); 390 391 REGISTER_OP("DirectedInterleaveDataset") 392 .Input("selector_input_dataset: variant") 393 .Input("data_input_datasets: N * variant") 394 .Output("handle: variant") 395 .Attr("output_types: list(type) >= 1") 396 .Attr("output_shapes: list(shape) >= 1") 397 .Attr("N: int >= 1") 398 .Attr("stop_on_empty_dataset: bool = false") 399 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 400 "output_types")) 401 .SetShapeFn(shape_inference::ScalarShape); 402 403 REGISTER_OP("ExperimentalDirectedInterleaveDataset") 404 .Input("selector_input_dataset: variant") 405 .Input("data_input_datasets: N * variant") 406 .Output("handle: variant") 407 .Attr("output_types: list(type) >= 1") 408 .Attr("output_shapes: list(shape) >= 1") 409 .Attr("N: int >= 1") 410 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 411 "output_types")) 412 .SetShapeFn(shape_inference::ScalarShape); 413 414 REGISTER_OP("GroupByReducerDataset") 415 .Input("input_dataset: variant") 416 .Input("key_func_other_arguments: Tkey_func_other_arguments") 417 .Input("init_func_other_arguments: 
Tinit_func_other_arguments") 418 .Input("reduce_func_other_arguments: Treduce_func_other_arguments") 419 .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments") 420 .Output("handle: variant") 421 .Attr("key_func: func") 422 .Attr("init_func: func") 423 .Attr("reduce_func: func") 424 .Attr("finalize_func: func") 425 .Attr("Tkey_func_other_arguments: list(type) >= 0") 426 .Attr("Tinit_func_other_arguments: list(type) >= 0") 427 .Attr("Treduce_func_other_arguments: list(type) >= 0") 428 .Attr("Tfinalize_func_other_arguments: list(type) >= 0") 429 .Attr("output_types: list(type) >= 1") 430 .Attr("output_shapes: list(shape) >= 1") 431 .SetIsStateful() 432 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 433 "output_types")) 434 .SetShapeFn(shape_inference::ScalarShape); 435 436 REGISTER_OP("ExperimentalGroupByReducerDataset") 437 .Input("input_dataset: variant") 438 .Input("key_func_other_arguments: Tkey_func_other_arguments") 439 .Input("init_func_other_arguments: Tinit_func_other_arguments") 440 .Input("reduce_func_other_arguments: Treduce_func_other_arguments") 441 .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments") 442 .Output("handle: variant") 443 .Attr("key_func: func") 444 .Attr("init_func: func") 445 .Attr("reduce_func: func") 446 .Attr("finalize_func: func") 447 .Attr("Tkey_func_other_arguments: list(type) >= 0") 448 .Attr("Tinit_func_other_arguments: list(type) >= 0") 449 .Attr("Treduce_func_other_arguments: list(type) >= 0") 450 .Attr("Tfinalize_func_other_arguments: list(type) >= 0") 451 .Attr("output_types: list(type) >= 1") 452 .Attr("output_shapes: list(shape) >= 1") 453 .SetIsStateful() 454 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 455 "output_types")) 456 .SetShapeFn(shape_inference::ScalarShape); 457 458 REGISTER_OP("GroupByWindowDataset") 459 .Input("input_dataset: variant") 460 .Input("key_func_other_arguments: Tkey_func_other_arguments") 461 
.Input("reduce_func_other_arguments: Treduce_func_other_arguments") 462 .Input( 463 "window_size_func_other_arguments: Twindow_size_func_other_arguments") 464 .Output("handle: variant") 465 .Attr("key_func: func") 466 .Attr("reduce_func: func") 467 .Attr("window_size_func: func") 468 .Attr("Tkey_func_other_arguments: list(type) >= 0") 469 .Attr("Treduce_func_other_arguments: list(type) >= 0") 470 .Attr("Twindow_size_func_other_arguments: list(type) >= 0") 471 .Attr("output_types: list(type) >= 1") 472 .Attr("output_shapes: list(shape) >= 1") 473 .Attr("metadata: string = ''") 474 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 475 "output_types")) 476 .SetShapeFn(shape_inference::ScalarShape); 477 478 REGISTER_OP("GetElementAtIndex") 479 .Input("dataset: variant") 480 .Input("index: int64") 481 .Output("components: output_types") 482 .Attr("output_types: list(type) >= 1") 483 .Attr("output_shapes: list(shape) >= 1") 484 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 485 "output_types")) 486 .SetShapeFn(shape_inference::DatasetIteratorShape); 487 488 REGISTER_OP("ExperimentalGroupByWindowDataset") 489 .Input("input_dataset: variant") 490 .Input("key_func_other_arguments: Tkey_func_other_arguments") 491 .Input("reduce_func_other_arguments: Treduce_func_other_arguments") 492 .Input( 493 "window_size_func_other_arguments: Twindow_size_func_other_arguments") 494 .Output("handle: variant") 495 .Attr("key_func: func") 496 .Attr("reduce_func: func") 497 .Attr("window_size_func: func") 498 .Attr("Tkey_func_other_arguments: list(type) >= 0") 499 .Attr("Treduce_func_other_arguments: list(type) >= 0") 500 .Attr("Twindow_size_func_other_arguments: list(type) >= 0") 501 .Attr("output_types: list(type) >= 1") 502 .Attr("output_shapes: list(shape) >= 1") 503 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 504 "output_types")) 505 .SetShapeFn(shape_inference::ScalarShape); 506 507 REGISTER_OP("IgnoreErrorsDataset") 508 
.Input("input_dataset: variant") 509 .Output("handle: variant") 510 .Attr("output_types: list(type) >= 1") 511 .Attr("output_shapes: list(shape) >= 1") 512 .Attr("log_warning: bool = false") 513 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 514 "output_types")) 515 .SetShapeFn(shape_inference::ScalarShape); 516 517 REGISTER_OP("ExperimentalIgnoreErrorsDataset") 518 .Input("input_dataset: variant") 519 .Output("handle: variant") 520 .Attr("output_types: list(type) >= 1") 521 .Attr("output_shapes: list(shape) >= 1") 522 .Attr("log_warning: bool = false") 523 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 524 "output_types")) 525 .SetShapeFn(shape_inference::ScalarShape); 526 527 REGISTER_OP("IteratorGetDevice") 528 .Input("resource: resource") 529 .Output("device: string") 530 .SetShapeFn(shape_inference::ScalarShape); 531 532 REGISTER_OP("ExperimentalIteratorGetDevice") 533 .Input("resource: resource") 534 .Output("device: string") 535 .SetShapeFn(shape_inference::ScalarShape); 536 537 REGISTER_OP("LatencyStatsDataset") 538 .Input("input_dataset: variant") 539 .Input("tag: string") 540 .Output("handle: variant") 541 .Attr("output_types: list(type) >= 1") 542 .Attr("output_shapes: list(shape) >= 1") 543 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 544 "output_types")) __anon46ca241c0c02(shape_inference::InferenceContext* c) 545 .SetShapeFn([](shape_inference::InferenceContext* c) { 546 shape_inference::ShapeHandle tag_shape; 547 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); 548 return shape_inference::ScalarShape(c); 549 }); 550 551 REGISTER_OP("ExperimentalLatencyStatsDataset") 552 .Input("input_dataset: variant") 553 .Input("tag: string") 554 .Output("handle: variant") 555 .Attr("output_types: list(type) >= 1") 556 .Attr("output_shapes: list(shape) >= 1") 557 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 558 "output_types")) 
__anon46ca241c0d02(shape_inference::InferenceContext* c) 559 .SetShapeFn([](shape_inference::InferenceContext* c) { 560 shape_inference::ShapeHandle tag_shape; 561 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); 562 return shape_inference::ScalarShape(c); 563 }); 564 565 REGISTER_OP("LMDBDataset") 566 .Input("filenames: string") 567 .Output("handle: variant") 568 .Attr("output_types: list(type) >= 1") 569 .Attr("output_shapes: list(shape) >= 1") 570 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 571 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 572 "output_types")) 573 .SetShapeFn(shape_inference::ScalarShape); 574 575 REGISTER_OP("ExperimentalLMDBDataset") 576 .Input("filenames: string") 577 .Output("handle: variant") 578 .Attr("output_types: list(type) >= 1") 579 .Attr("output_shapes: list(shape) >= 1") 580 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 581 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 582 "output_types")) 583 .SetShapeFn(shape_inference::ScalarShape); 584 585 REGISTER_OP("MapAndBatchDataset") 586 .Input("input_dataset: variant") 587 .Input("other_arguments: Targuments") 588 .Input("batch_size: int64") 589 .Input("num_parallel_calls: int64") 590 .Input("drop_remainder: bool") 591 .Output("handle: variant") 592 .Attr("f: func") 593 .Attr("Targuments: list(type) >= 0") 594 .Attr("output_types: list(type) >= 1") 595 .Attr("output_shapes: list(shape) >= 1") 596 .Attr("preserve_cardinality: bool = false") 597 .Attr("metadata: string = ''") 598 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 599 "output_types")) __anon46ca241c0e02(shape_inference::InferenceContext* c) 600 .SetShapeFn([](shape_inference::InferenceContext* c) { 601 // Use index from the end to retrieve the Input shapes, 602 // so that to avoid guessing the length of "other_arguments". 603 // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars. 
604 shape_inference::ShapeHandle unused; 605 TF_RETURN_IF_ERROR( 606 c->WithRank(c->input(c->num_inputs() - 3), 0, &unused)); 607 TF_RETURN_IF_ERROR( 608 c->WithRank(c->input(c->num_inputs() - 2), 0, &unused)); 609 TF_RETURN_IF_ERROR( 610 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); 611 612 return shape_inference::ScalarShape(c); 613 }); 614 615 REGISTER_OP("ExperimentalMapAndBatchDataset") 616 .Input("input_dataset: variant") 617 .Input("other_arguments: Targuments") 618 .Input("batch_size: int64") 619 .Input("num_parallel_calls: int64") 620 .Input("drop_remainder: bool") 621 .Output("handle: variant") 622 .Attr("f: func") 623 .Attr("Targuments: list(type) >= 0") 624 .Attr("output_types: list(type) >= 1") 625 .Attr("output_shapes: list(shape) >= 1") 626 .Attr("preserve_cardinality: bool = false") 627 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 628 "output_types")) __anon46ca241c0f02(shape_inference::InferenceContext* c) 629 .SetShapeFn([](shape_inference::InferenceContext* c) { 630 // Use index from the end to retrieve the Input shapes, 631 // so that to avoid guessing the length of "other_arguments". 632 // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars. 
633 shape_inference::ShapeHandle unused; 634 TF_RETURN_IF_ERROR( 635 c->WithRank(c->input(c->num_inputs() - 3), 0, &unused)); 636 TF_RETURN_IF_ERROR( 637 c->WithRank(c->input(c->num_inputs() - 2), 0, &unused)); 638 TF_RETURN_IF_ERROR( 639 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); 640 641 return shape_inference::ScalarShape(c); 642 }); 643 644 REGISTER_OP("ExperimentalMapDataset") 645 .Input("input_dataset: variant") 646 .Input("other_arguments: Targuments") 647 .Output("handle: variant") 648 .Attr("f: func") 649 .Attr("Targuments: list(type) >= 0") 650 .Attr("output_types: list(type) >= 1") 651 .Attr("output_shapes: list(shape) >= 1") 652 .Attr("use_inter_op_parallelism: bool = true") 653 .Attr("preserve_cardinality: bool = false") 654 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 655 "output_types")) 656 .SetShapeFn(shape_inference::ScalarShape); 657 658 REGISTER_OP("MatchingFilesDataset") 659 .Input("patterns: string") 660 .Output("handle: variant") 661 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 662 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, 663 TFT_STRING)) __anon46ca241c1002(shape_inference::InferenceContext* c) 664 .SetShapeFn([](shape_inference::InferenceContext* c) { 665 shape_inference::ShapeHandle unused; 666 // `patterns` must be a scalar or a vector. 667 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 668 return shape_inference::ScalarShape(c); 669 }); 670 671 REGISTER_OP("ExperimentalMatchingFilesDataset") 672 .Input("patterns: string") 673 .Output("handle: variant") 674 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 675 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, 676 TFT_STRING)) __anon46ca241c1102(shape_inference::InferenceContext* c) 677 .SetShapeFn([](shape_inference::InferenceContext* c) { 678 shape_inference::ShapeHandle unused; 679 // `patterns` must be a scalar or a vector. 
680 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 681 return shape_inference::ScalarShape(c); 682 }); 683 684 REGISTER_OP("MaxIntraOpParallelismDataset") 685 .Input("input_dataset: variant") 686 .Input("max_intra_op_parallelism: int64") 687 .Output("handle: variant") 688 .Attr("output_types: list(type) >= 1") 689 .Attr("output_shapes: list(shape) >= 1") 690 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 691 "output_types")) 692 .SetShapeFn(shape_inference::ScalarShape); 693 694 REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset") 695 .Input("input_dataset: variant") 696 .Input("max_intra_op_parallelism: int64") 697 .Output("handle: variant") 698 .Attr("output_types: list(type) >= 1") 699 .Attr("output_shapes: list(shape) >= 1") 700 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 701 "output_types")) 702 .SetShapeFn(shape_inference::ScalarShape); 703 704 REGISTER_OP("NonSerializableDataset") 705 .Input("input_dataset: variant") 706 .Output("handle: variant") 707 .Attr("output_types: list(type) >= 1") 708 .Attr("output_shapes: list(shape) >= 1") 709 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 710 "output_types")) 711 .SetShapeFn(shape_inference::ScalarShape); 712 713 REGISTER_OP("ExperimentalNonSerializableDataset") 714 .Input("input_dataset: variant") 715 .Output("handle: variant") 716 .Attr("output_types: list(type) >= 1") 717 .Attr("output_shapes: list(shape) >= 1") 718 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 719 "output_types")) 720 .SetShapeFn(shape_inference::ScalarShape); 721 722 REGISTER_OP("ParallelInterleaveDataset") 723 .Input("input_dataset: variant") 724 .Input("other_arguments: Targuments") 725 .Input("cycle_length: int64") 726 .Input("block_length: int64") 727 .Input("sloppy: bool") 728 .Input("buffer_output_elements: int64") 729 .Input("prefetch_input_elements: int64") 730 .Output("handle: variant") 731 .Attr("f: func") 732 
.Attr("Targuments: list(type) >= 0") 733 .Attr("output_types: list(type) >= 1") 734 .Attr("output_shapes: list(shape) >= 1") 735 .Attr("metadata: string = ''") 736 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 737 "output_types")) 738 .SetShapeFn(shape_inference::ScalarShape); 739 740 // This is the V2 of ParallelInterleaveDataset, renamed to differentiate it 741 // from the non-experimental ParallelInterleaveDataset op. 742 REGISTER_OP("LegacyParallelInterleaveDatasetV2") 743 .Input("input_dataset: variant") 744 .Input("other_arguments: Targuments") 745 .Input("cycle_length: int64") 746 .Input("block_length: int64") 747 .Input("buffer_output_elements: int64") 748 .Input("prefetch_input_elements: int64") 749 .Output("handle: variant") 750 .Attr("f: func") 751 // "true", "false", or "default". 752 .Attr("deterministic: string = 'default'") 753 .Attr("Targuments: list(type) >= 0") 754 .Attr("output_types: list(type) >= 1") 755 .Attr("output_shapes: list(shape) >= 1") 756 .Attr("metadata: string = ''") 757 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 758 "output_types")) 759 .SetShapeFn(shape_inference::ScalarShape); 760 761 // This op is no longer used. We keep it so that we can read graphs written by 762 // old versions of TensorFlow. 
// Same inputs as ParallelInterleaveDataset, but without the `metadata` attr.
REGISTER_OP("ExperimentalParallelInterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("sloppy: bool")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Example-parsing dataset op. Sparse, dense, and ragged feature specs are
// passed as attrs; `dense_defaults` is a per-dense-key input list typed by
// `Tdense`. The output is a scalar variant handle.
REGISTER_OP("ParseExampleDataset")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                              // sorted by key (dense_keys and
                                              // sparse_keys combined) here.
    .Attr("sloppy: bool = false")
    .Attr("ragged_keys: list(string) >= 0 = []")
    .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
    .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like ParseExampleDataset, but the `sloppy` bool attr is replaced by the
// `deterministic` string attr.
REGISTER_OP("ParseExampleDatasetV2")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                              // sorted by key (dense_keys and
                                              // sparse_keys combined) here.
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("ragged_keys: list(string) >= 0 = []")
    .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
    .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like ParseExampleDataset, but without the `ragged_*` attrs.
REGISTER_OP("ExperimentalParseExampleDataset")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                              // sorted by key (dense_keys and
                                              // sparse_keys combined) here.
    .Attr("sloppy: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Dataset op taking `input_dataset` and an int64 `num_threads` (no explicit
// rank check); the output is a scalar variant handle.
REGISTER_OP("PrivateThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("num_threads: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Same signature as PrivateThreadPoolDataset; presumably the older
// `Experimental*` name (TODO confirm).
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("num_threads: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Source dataset op (no input dataset) driven by two int64 seeds. Both seeds
// must be scalars.
REGISTER_OP("ExperimentalRandomDataset")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // seed and seed2 should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Same as ExperimentalRandomDataset plus a `metadata` attr.
REGISTER_OP("RandomDataset")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // seed and seed2 should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Same as RebatchDataset except that it registers no forward type function.
REGISTER_OP("ExperimentalRebatchDataset")
    .Input("input_dataset: variant")
    .Input("num_replicas: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_fallback: bool = true")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Rebatching dataset op taking an int64 `num_replicas`. Its forward type
// function maps the input dataset's element types with
// full_type::BatchTensor.
REGISTER_OP("RebatchDataset")
    .Input("input_dataset: variant")
    .Input("num_replicas: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_fallback: bool = true")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::BatchTensor))
    .SetShapeFn(shape_inference::ScalarShape);

// Variant of RebatchDataset that takes `batch_sizes` and `drop_remainder`
// inputs instead of `num_replicas`, and has no `use_fallback` attr.
REGISTER_OP("RebatchDatasetV2")
    .Input("input_dataset: variant")
    .Input("batch_sizes: int64")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::BatchTensor))
    .SetShapeFn(shape_inference::ScalarShape);

// Dataset op taking a float32 `rate` and two int64 seeds. All three must be
// scalars (inputs 1-3).
REGISTER_OP("SamplingDataset")
    .Input("input_dataset: variant")
    .Input("rate: float32")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // rate, seed, and seed2 should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Scan-style dataset op. The function attr `f` is paired with an
// `initial_state` input list (typed by `Tstate`, at least one tensor) and
// captured `other_arguments`. The output is a scalar variant handle.
REGISTER_OP("ScanDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("preserve_cardinality: bool = false")
    .Attr("use_default_device: bool = true")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like ScanDataset, but without the `use_default_device` and `metadata`
// attrs.
REGISTER_OP("ExperimentalScanDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("preserve_cardinality: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Dataset op taking a `stats_aggregator` resource plus `tag` and
// `counter_prefix` strings (no explicit rank checks); output is a scalar
// handle.
REGISTER_OP("SetStatsAggregatorDataset")
    .Input("input_dataset: variant")
    .Input("stats_aggregator: resource")
    .Input("tag: string")
    .Input("counter_prefix: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Same signature as SetStatsAggregatorDataset; presumably the older
// `Experimental*` name (TODO confirm).
REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
    .Input("input_dataset: variant")
    .Input("stats_aggregator: resource")
    .Input("tag: string")
    .Input("counter_prefix: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Dataset op taking an int64 `sleep_microseconds`. Both inputs are checked
// with WithRankAtMost(..., 0), which accepts rank 0 or unknown rank.
REGISTER_OP("SleepDataset")
    .Input("input_dataset: variant")
    .Input("sleep_microseconds: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // Both inputs are scalar.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Same signature and shape function as SleepDataset; presumably the older
// `Experimental*` name (TODO confirm).
REGISTER_OP("ExperimentalSleepDataset")
    .Input("input_dataset: variant")
    .Input("sleep_microseconds: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // Both inputs are scalar.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Windowing dataset op. The three int64 window parameters (inputs 1-3) must
// be scalars.
REGISTER_OP("SlidingWindowDataset")
    .Input("input_dataset: variant")
    .Input("window_size: int64")
    .Input("window_shift: int64")
    .Input("window_stride: int64")
    .Output("handle: variant")
    .Attr("drop_remainder: bool = true")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // window_size, window_shift, and window_stride should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Like SlidingWindowDataset, but without the `drop_remainder` attr.
REGISTER_OP("ExperimentalSlidingWindowDataset")
    .Input("input_dataset: variant")
    .Input("window_size: int64")
    .Input("window_shift: int64")
    .Input("window_stride: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // window_size, window_shift, and window_stride should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Snapshot dataset op configured mostly through attrs; the only runtime
// input besides the dataset is the string `path`, which must be a scalar.
REGISTER_OP("SnapshotDataset")
    .Input("input_dataset: variant")
    .Input("path: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("compression: string = ''")
    .Attr("reader_path_prefix: string = ''")
    .Attr("writer_path_prefix: string = ''")
    .Attr("shard_size_bytes: int = 10737418240")           // 10 GiB default
    .Attr("pending_snapshot_expiry_seconds: int = 86400")  // 1 day default
    .Attr("num_reader_threads: int = 1")
    .Attr("reader_buffer_size: int = 1")
    .Attr("num_writer_threads: int = 1")
    .Attr("writer_buffer_size: int = 1")
    .Attr("shuffle_on_read: bool = false")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("mode: string = 'auto'")
    .Attr("snapshot_name: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Snapshot op that replaces most of SnapshotDataset's tuning attrs with
// `reader_func` / `shard_func` function attrs and their captured argument
// lists.
REGISTER_OP("SnapshotDatasetV2")
    .Input("input_dataset: variant")
    .Input("path: string")
    .Input("reader_func_other_args: Treader_func_args")
    .Input("shard_func_other_args: Tshard_func_args")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("compression: string = ''")
    .Attr("reader_prefix: string = ''")
    .Attr("writer_prefix: string = ''")
    .Attr("hash_valid: bool = false")
    .Attr("hash: int = 0")
    .Attr("reader_func: func")
    .Attr("shard_func: func")
    .Attr("Treader_func_args: list(type) >= 0")
    .Attr("Tshard_func_args: list(type) >= 0")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Stateful op with no outputs: the shape function only validates that
// `path` is a scalar.
REGISTER_OP("SaveDataset")
    .Input("input_dataset: variant")
    .Input("path: string")
    .Input("shard_func_other_args: Tshard_func_args")
    .Attr("compression: string = ''")
    .Attr("shard_func: func")
    .Attr("use_shard_func: bool = true")
    .Attr("Tshard_func_args: list(type) >= 0")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return OkStatus();
    });

// Like SaveDataset, but also produces a scalar variant `handle` and carries
// `output_types`/`output_shapes` attrs. Unlike most ops here, no type
// constructor is registered.
REGISTER_OP("SaveDatasetV2")
    .Input("input_dataset: variant")
    .Input("path: string")
    .Input("shard_func_other_args: Tshard_func_args")
    .Output("handle: variant")
    .Attr("compression: string = ''")
    .Attr("shard_func: func")
    .Attr("use_shard_func: bool = true")
    .Attr("Tshard_func_args: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Stateful source op: `path` (input 0) must be a scalar; `reader_func` is
// paired with captured `reader_func_other_args`.
REGISTER_OP("LoadDataset")
    .Input("path: string")
    .Input("reader_func_other_args: Treader_func_args")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("compression: string = ''")
    .Attr("reader_func: func")
    .Attr("Treader_func_args: list(type) >= 0")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `path` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Source op taking a string `shard_dir` and an int64 `start_index`, both of
// which must be scalars. `version` is a required int attr (no default).
REGISTER_OP("SnapshotDatasetReader")
    .Input("shard_dir: string")
    .Input("start_index: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("compression: string = ''")
    .Attr("version: int")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `shard_dir` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      // `start_index` should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Takes N >= 1 variant inputs (no rank checks) and produces one scalar
// handle.
REGISTER_OP("SnapshotNestedDatasetReader")
    .Input("inputs: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Source op taking three string inputs, all of which must be scalars.
REGISTER_OP("SqlDataset")
    .Input("driver_name: string")
    .Input("data_source_name: string")
    .Input("query: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // driver_name, data_source_name, and query should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Same signature and shape function as SqlDataset; presumably the older
// `Experimental*` name (TODO confirm).
REGISTER_OP("ExperimentalSqlDataset")
    .Input("driver_name: string")
    .Input("data_source_name: string")
    .Input("query: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // driver_name, data_source_name, and query should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Produces a scalar resource handle; `container`/`shared_name` attrs default
// to empty.
REGISTER_OP("StatsAggregatorHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// Same signature as StatsAggregatorHandle.
REGISTER_OP("ExperimentalStatsAggregatorHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// Same signature as StatsAggregatorHandle.
REGISTER_OP("StatsAggregatorHandleV2")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// Takes two resource inputs and has no outputs.
REGISTER_OP("StatsAggregatorSetSummaryWriter")
    .Input("stats_aggregator: resource")
    .Input("summary: resource")
    .SetShapeFn(shape_inference::NoOutputs);
// Takes a resource input named `iterator` and produces a scalar string
// `summary`.
REGISTER_OP("StatsAggregatorSummary")
    .Input("iterator: resource")
    .Output("summary: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Same signature as StatsAggregatorSummary.
REGISTER_OP("ExperimentalStatsAggregatorSummary")
    .Input("iterator: resource")
    .Output("summary: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Dataset op parameterized by the function attr `predicate` and captured
// `other_arguments`; the output is a scalar variant handle.
REGISTER_OP("TakeWhileDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like TakeWhileDataset, but without the `metadata` attr.
REGISTER_OP("ExperimentalTakeWhileDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Dataset op taking a `thread_pool` resource input; the output is a scalar
// variant handle.
REGISTER_OP("ThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("thread_pool: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Same signature as ThreadPoolDataset; presumably the older `Experimental*`
// name (TODO confirm).
REGISTER_OP("ExperimentalThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("thread_pool: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Produces a scalar resource handle. `num_threads` and `display_name` are
// required attrs (no defaults).
REGISTER_OP("ThreadPoolHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("num_threads: int")
    .Attr("max_intra_op_parallelism: int = 1")
    .Attr("display_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// Same signature as ThreadPoolHandle.
REGISTER_OP("ExperimentalThreadPoolHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("num_threads: int")
    .Attr("max_intra_op_parallelism: int = 1")
    .Attr("display_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// Dataset op taking only `input_dataset`; the output is a scalar variant
// handle.
REGISTER_OP("UnbatchDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like UnbatchDataset, but without the `metadata` attr.
REGISTER_OP("ExperimentalUnbatchDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Dataset op taking only `input_dataset`; the output is a scalar variant
// handle.
REGISTER_OP("UniqueDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like UniqueDataset, but without the `metadata` attr.
REGISTER_OP("ExperimentalUniqueDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Op with no inputs that produces one resource handle; the shape function
// sets output 0 to a scalar directly.
REGISTER_OP("DummyIterationCounter")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

// Stateful op that produces a scalar variant dataset handle from an int64
// `dataset_id`, string connection inputs, and an `iteration_counter`
// resource. No input rank checks are performed here.
REGISTER_OP("DataServiceDataset")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Adds `consumer_index` and `num_consumers` arguments to support round-robin
// reads.
// Inputs are those of DataServiceDataset plus int64 `consumer_index` and
// `num_consumers`, inserted before `max_outstanding_requests`. It is stateful
// and its output is a scalar variant handle.
REGISTER_OP("DataServiceDatasetV2")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Adds `uncompress` and `uncompress_fn` attributes to support uncompression.
// Inputs are identical to DataServiceDatasetV2. `uncompress_fn` is a required
// function attr (no default).
REGISTER_OP("DataServiceDatasetV3")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("uncompress: bool = false")
    .Attr("uncompress_fn: func")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Changes `dataset_id` from int64 to string.
// Identical to DataServiceDatasetV3 except that input 0 (`dataset_id`) is a
// string rather than int64.
REGISTER_OP("DataServiceDatasetV4")
    .Input("dataset_id: string")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("uncompress: bool = false")
    .Attr("uncompress_fn: func")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Takes a dataset variant plus string `address`/`protocol` and produces a
// scalar int64 `dataset_id`. `external_state_policy` is a required int attr
// (no default).
REGISTER_OP("RegisterDataset")
    .Input("dataset: variant")
    .Input("address: string")
    .Input("protocol: string")
    .Output("dataset_id: int64")
    .Attr("external_state_policy: int")
    .Attr("element_spec: string = ''")
    .Attr("metadata: string = ''")
    .SetShapeFn(shape_inference::ScalarShape);

// Changes `dataset_id` from int64 to string.
// Like RegisterDataset, but the scalar `dataset_id` output is a string, and
// there is an additional `requested_dataset_id` string attr.
REGISTER_OP("RegisterDatasetV2")
    .Input("dataset: variant")
    .Input("address: string")
    .Input("protocol: string")
    .Output("dataset_id: string")
    .Attr("external_state_policy: int")
    .Attr("element_spec: string = ''")
    .Attr("requested_dataset_id: string = ''")
    .Attr("metadata: string = ''")
    .SetShapeFn(shape_inference::ScalarShape);

// Op with no outputs; its shape function only requires both `table_handle`
// and `dataset` to be scalars.
REGISTER_OP("InitializeTableFromDataset")
    .Input("table_handle: resource")
    .Input("dataset: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle handle;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
      return OkStatus();
    });

// - `output_types` is the types of tensors in a single dataset element.
// - `output_shapes` is the shapes of tensors in a single dataset element.
// - `output_types` and `output_shapes` are the same size: the number of
//   tensors in a single dataset element, a.k.a. the number of components.
// - `Tinput_types` is the types of tensors for all dataset elements.
//   `Tinput_types` is equivalent to `output_types` repeated for N total dataset
//   elements.
REGISTER_OP("ListDataset")
    .Input("tensors: Tinput_types")
    .Output("handle: variant")
    .Attr("Tinput_types: list(type) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

}  // namespace tensorflow