1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// This is the operation definition file for TensorFlow. 17// 18// This file contains TensorFlow ops whose definitions are amended to fix 19// issues or provide more information. In this file you have full control 20// of the op definition; all changes will be retained with subsequent 21// refreshes. 22// 23// This file includes another file, `tf_generated_ops.td`, which contains 24// all ops whose definitions are generated from TensorFlow codebase. 25// Changes made there are not respected. 
#ifndef TF_OPS
#define TF_OPS

include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"

// Base class for ops that create a TensorList. Every such op produces a single
// variant-typed `handle` result and is marked NoSideEffect. `element_dtype` is
// not a stored attribute: it is computed from the first subtype of the
// handle's variant type (see `element_type()` below).
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
  let results = (outs
    TF_VariantTensor:$handle
  );

  // Derived attribute keyed on operand index 0 of the concrete op.
  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;

  let hasVerifier = 1;

  DerivedTypeAttr element_dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(element_type());">;

  let extraClassDeclaration = [{
    // Returns type of the TensorList element produced by this op. This reads
    // subtype #0 of the handle's variant type without checking that a subtype
    // is present.
    TensorType element_type() { return handle_dtype().getSubtypes()[0]; }

    // Returns data type of the result handle. Returned type contains type of
    // the TensorList element as a subtype.
    VariantType handle_dtype() {
      return getElementTypeOrSelf(handle().getType()).cast<TF::VariantType>();
    }
  }];
}

// Functional form of the n-way switch: each branch is a function referenced by
// symbol through the `branches` array attribute, which must hold at least one
// entry (ArrayMinCount<1>).
def TF_CaseOp : TF_Op<"Case", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0](input);
        break;
      case 1:
        output = branches[1](input);
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1](input);
        break;
    }
    ```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,
    Variadic<TF_Tensor>:$input,

    ConfinedAttr<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,

    // Used to map StatelessCase and Case op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes: `Tin` is keyed on operand index 1 (`input`), `Tout`
  // and `output_shapes` on result index 0 (`output`).
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasCanonicalizer = 1;

  let hasVerifier = 1;


  let extraClassDeclaration = [{
    // Number of entries in the `branches` attribute.
    int num_branches() { return branches().size(); }

    // Gets function corresponding branch # `index`.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // `index` subscripts `branches()` directly; no bounds check is done here.
    func::FuncOp ResolveBranchFunction(::mlir::SymbolTableCollection* table, int index) {
      auto flat_sym_ref = branches()[index].cast<FlatSymbolRefAttr>();
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, flat_sym_ref);
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, flat_sym_ref);
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp branch_function(int index) { return ResolveBranchFunction(nullptr, index); }

    // Resolve all branch functions.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // Results are appended to `functions` in branch order.
    void ResolveBranchFunctions(::mlir::SymbolTableCollection* table,
                                SmallVectorImpl<func::FuncOp> &functions) {
      functions.reserve(num_branches());
      for (int idx : llvm::seq<int>(0, num_branches()))
        functions.push_back(ResolveBranchFunction(table, idx));
    }
    // TODO(b/204997177): Deprecate and remove.
    void get_branch_functions(SmallVectorImpl<func::FuncOp> &functions) {
      return ResolveBranchFunctions(nullptr, functions);
    }
  }];
}

// Region-based form of the n-way switch. Unlike TF_CaseOp there is no `input`
// operand and no symbol references: every branch is an inline single-block
// region without block arguments, implicitly terminated by YieldOp.
def TF_CaseRegionOp : TF_Op<"CaseRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = [{
An n-way switch statement which executes a single branch region.
  }];

  let description = [{
An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0]();
        break;
      case 1:
        output = branches[1]();
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1]();
        break;
    }
    ```
Each branch is a region that takes no arguments and returns its results using
tf.yield (as the terminator).
  }];

  let arguments = (ins
    I32Tensor:$branch_index,

    // Used to map StatelessCase and Case op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let hasVerifier = 1;

  let hasCanonicalizer = 1;

}

// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
// its type encoding the tensor's shape and data type.
// The op has no operands: the constant is carried by the `value` attribute and
// the single result is a tensor. `dtype` is a derived attribute keyed on
// result index 0.
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
  let summary = "Constant tensor op";

  let arguments = (ins
    ElementsAttr:$value
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;

  let builders = [
    OpBuilder<(ins "Attribute":$value)>,
    OpBuilder<(ins "Type":$type, "Attribute":$value)>,
  ];

  let hasFolder = 1;

  let extraClassDeclaration = [{
    // InferTypeOpInterface hook: inferred and actual result types are accepted
    // when BroadcastCompatible(l, r) holds.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return BroadcastCompatible(l, r);
    }
  }];
}

// TensorList-creating op built on TF_TensorListInitOp, which supplies the
// `handle` result. Operand 0 is `element_shape`.
def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
  let summary = "Creates and returns an empty tensor list.";

  let description = [{
All list elements must be tensors of dtype element_dtype and shape compatible
with element_shape.

handle: an empty tensor list.
element_dtype: the type of elements in the list.
element_shape: a shape compatible with that of elements in the list.
  }];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$max_num_elements
  );
}

// Functional two-way conditional: both branches are functions referenced by
// symbol (`then_branch` / `else_branch`).
def TF_IfOp : TF_Op<"If", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = "output = cond ? then_branch(input) : else_branch(input)";

  let description = [{
output = cond ? then_branch(input) : else_branch(input)

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
input: A list of input tensors.
then_branch: A function that takes 'inputs' and returns a list of
    tensors, whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of
    tensors. whose types are the same as what then_branch returns.
  }];

  let arguments = (ins
    TF_Tensor:$cond,
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$then_branch,
    FlatSymbolRefAttr:$else_branch,

    // Used to map StatelessIf and If op defined in TensorFlow to a common op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes: `Tcond` is keyed on operand index 0 (`cond`), `Tin` on
  // operand index 1 (`input`), `Tout` / `output_shapes` on result index 0.
  TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // Resolve the then branch function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // A null `table` falls back to the static SymbolTable lookup.
    func::FuncOp ResolveThenFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, then_branchAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
        *this, then_branchAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp then_function(::mlir::SymbolTableCollection* table = nullptr) {
      return ResolveThenFunction(table);
    }

    // Resolve the else branch function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // A null `table` falls back to the static SymbolTable lookup.
    func::FuncOp ResolveElseFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, else_branchAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
        *this, else_branchAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp else_function(::mlir::SymbolTableCollection* table = nullptr) {
      return ResolveElseFunction(table);
    }
  }];
}

// Terminator used inside the regions of the region-based control flow ops.
// ParentOneOf restricts its parent op to CaseRegionOp, IfRegionOp or
// WhileRegionOp.
def TF_YieldOp : TF_Op<"Yield",
      [NoSideEffect, ReturnLike, Terminator,
       ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
  let summary = "Yield operation";

  let description = [{
    The "yield" operation represents a return operation within the conditional
    and body of structured control flow (e.g., if and while). The operation
    takes a variable number of operands and produces no results. The number and
    types of inputs must match the signature of the operation that contains the
    region.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);
}

// Region-based form of tf.If: `then_branch` and `else_branch` are inline
// single-block regions without block arguments. Unlike TF_IfOp, `cond` is
// constrained to a 0-D i1 tensor and there is no `input` operand.
def TF_IfRegionOp : TF_Op<"IfRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = "output = cond ? then_branch output : else_branch output";

  let description = [{
"output = cond ? then_branch output : else_branch output"

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
then_branch: A region that computes the outputs of the op if cond = true.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors is same as that of the else_branch
else_branch: A region that computes the outputs of the op if cond = false.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors is same as that of the then_branch
  }];

  let arguments = (ins
    0DTensorOf<[I1]>:$cond,

    // Used to map StatelessIf and If op defined in TensorFlow to a common op.
    BoolAttr:$is_stateless,
    // Used to maintain function name when round-tripping
    // between functional and regional control flow.  This can be removed if
    // the runtime does not require globally unique then/else branch function names.
    OptionalAttr<StrAttr>:$_then_func_name,
    OptionalAttr<StrAttr>:$_else_func_name
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch);

  let hasRegionVerifier = 1;

  // Extra builder that accepts an explicit region count. The count is only
  // asserted to be 2 and then dropped; construction is forwarded to the
  // builder without the `numRegions` parameter.
  let builders = [
    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
      "llvm::ArrayRef<::mlir::NamedAttribute>":$attributes,
      "unsigned":$numRegions),
    [{
      assert(numRegions == 2u && "mismatched number of regions");
      build($_builder, $_state, resultTypes, operands, attributes);
    }]>];

  let hasCanonicalizer = 1;
}

// Direct call to the function named by `f`. CallOpInterface is listed as a
// plain trait, so its methods are provided by hand in extraClassDeclaration.
def TF_LegacyCallOp : TF_Op<"LegacyCall",
                            [CallOpInterface, NoSideEffect]> {
  let summary =
    "returns `f(inputs)`, where `f` is a function.";

  let description = [{
    The LegacyCall operation represents a direct call to a function that is
    within the same symbol scope as the call and is mapped to a GraphDef node
    with the function name as the op name. Unlike a PartitionedCall which
    represents asynchronously executing a function across multiple devices, a
    LegacyCall ignores specification for ops in the attached function and
    instead executes it on the device assigned to this op.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    DefaultValuedAttr<BoolAttr, "false">:$_disable_call_shape_inference
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the resolved callee function of this operation.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // A null `table` falls back to the static SymbolTable lookup.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() { return ResolveFunc(nullptr); }
  }];
}

// Declared with AttrSizedOperandSegments and AttrSizedResultSegments: the op
// has several variadic operand groups and several variadic result groups, and
// the `operand_segment_sizes` / `result_segment_sizes` attributes are listed
// explicitly among the arguments.
def TF_ParseExampleOp : TF_Op<"ParseExample",
                               [NoSideEffect,
                                AttrSizedResultSegments,
                                AttrSizedOperandSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    Variadic<TF_StrTensor>:$sparse_keys,
    Variadic<TF_StrTensor>:$dense_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    TF_ShapeAttrArray:$dense_shapes,
    DenseI32ArrayAttr:$result_segment_sizes,
    DenseI32ArrayAttr:$operand_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                           // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,  // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values    // len(Tdense)
  );

  // Derived attributes, keyed by position in the lists above: operand index 2
  // is `sparse_keys`, 3 is `dense_keys`, 4 is `dense_defaults`; result index 1
  // is `sparse_values`.
  TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
  TF_DerivedOperandSizeAttr Ndense = TF_DerivedOperandSizeAttr<3>;
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<4>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;

  let hasVerifier = 0;
}

// V2 variant: the key operands are single tensors here and `dense_defaults` is
// the only variadic operand, so only AttrSizedResultSegments is attached.
def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
                                 [NoSideEffect,
                                  AttrSizedResultSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    TF_StrTensor:$sparse_keys,
    TF_StrTensor:$dense_keys,
    TF_StrTensor:$ragged_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    ConfinedAttr<I64Attr, [IntMinValue<0>]>:$num_sparse,
    TF_ShapeAttrArray:$dense_shapes,
    DenseI32ArrayAttr:$result_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                           // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,  // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values,   // len(Tdense)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$ragged_values,  // len(ragged_value_types)
                                                                        //   = len(ragged_split_types)
    Variadic<TensorOf<[TF_Int32, TF_Int64]>>:$ragged_row_splits         // len(ragged_split_types)
                                                                        //   = len(ragged_value_types)
  );

  // The Verify(ParseExampleV2Op) function validates that the lengths and types
  // of these attrs are compatible.
  // Positions refer to the lists above: operand index 5 is `dense_defaults`;
  // result indices 1, 4 and 5 are `sparse_values`, `ragged_values` and
  // `ragged_row_splits`.
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<5>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;
  TF_DerivedResultTypeListAttr ragged_value_types =
    TF_DerivedResultTypeListAttr<4>;
  TF_DerivedResultTypeListAttr ragged_split_types =
    TF_DerivedResultTypeListAttr<5>;

  let hasVerifier = 1;
}

// Takes no operands and no attributes; `dtype` is derived from result index 0.
def TF_PlaceholderOp : TF_Op<"Placeholder", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
Inserts a placeholder for a tensor that will be always fed.
  }];

  let arguments = (ins
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}

// Placeholder with a single `input` operand; `dtype` is derived from result
// index 0 and `shape` uses TF_DerivedResultShapeAttr.
def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
    A placeholder op that passes through input when its output is not fed.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
  DerivedAttr shape = TF_DerivedResultShapeAttr;
}

// Both CallOpInterface and SymbolUserOpInterface are listed as plain traits
// (not through DeclareOpInterfaceMethods), so their methods, including
// verifySymbolUses, are declared by hand in extraClassDeclaration.
def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
                                         [CallOpInterface, SymbolUserOpInterface]> {
  let summary =
    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";

  let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    StrAttr:$config,
    StrAttr:$config_proto,
    StrAttr:$executor_type
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the resolved callee function of this operation.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // A null `table` falls back to the static SymbolTable lookup.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() { return ResolveFunc(nullptr); }

    // SymbolUserOpInterface verifier.
    LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable);
  }];
}

// Functional while loop: `cond` and `body` are functions referenced by symbol.
// `parallel_iterations` defaults to 10 and is constrained to be at least 1.
def TF_WhileOp : TF_Op<"While", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
output = input; While (Cond(output)) { output = Body(output) }
  }];

  let description = [{
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor.  If the tensor is
    a scalar of non-boolean, the scalar is converted to a boolean
    according to the following rule: if the scalar is a numerical
    value, non-zero means True and zero means False; if the scalar is
    a string, non-empty means True and empty means False. If the
    tensor is not a scalar, non-emptiness means True and False
    otherwise.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$cond,
    FlatSymbolRefAttr:$body,
    ConfinedAttr<DefaultValuedAttr<I64Attr, "10">, [IntMinValue<1>]>:$parallel_iterations,

    // Used to map StatelessWhile and While op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless,

    // In TensorFlow, While has a special behavior where if `output_shapes`
    // attribute is not empty, those shapes are used in its shape function
    // as result shapes instead of propagating operand shapes as result shapes.
    // This allows for different result shapes from operand shapes. While these
    // shapes are imported and set as a part of the result type, there is no
    // indicator differentiating between having no output shapes compared to
    // having all unranked shapes. Thus this attribute is set to determine
    // which shape function behavior to use for this op, specifically
    // propagating operand shapes as result shapes when this attribute is not
    // set, or preserving result shapes as is when this attribute is set.
    UnitAttr:$shape_invariant
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let extraClassDeclaration = [{
    // Get the condition function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // A null `table` falls back to the static SymbolTable lookup.
    func::FuncOp ResolveCondFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, condAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, condAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp cond_function() { return ResolveCondFunction(nullptr); }

    // Get the body function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // A null `table` falls back to the static SymbolTable lookup.
    func::FuncOp ResolveBodyFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, bodyAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, bodyAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp body_function() { return ResolveBodyFunction(nullptr); }
  }];
}

// Region-based while loop with two single-block regions, `cond` and `body`.
// Unlike the other region ops in this file it does not carry
// NoRegionArguments, and its operands/results use AnyTensor.
def TF_WhileRegionOp : TF_Op<"WhileRegion",
      [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
       SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "while operation";
  let description = [{
  The tf.WhileRegion op represents a while loop using 2 regions and a set of
  iteration variables. The iteration variables maintained by this Op have the
  same types as the inputs. The Op executes a while loop described by the
  following pseudo code:

  ```
     func WhileRegionOp(inputs) {
       iteration_vars = inputs;
       while (cond(iteration_vars)) {
           iteration_vars = body(iteration_vars);
       }
       return iteration_vars;
     }
  ```

  `cond` is the condition region and `body` is the body region. Both these
  regions accept the current value of the iteration variables as inputs. The
  condition region returns a tensor<i1> which, if false, will exit the loop.
  The body region computes new values of the iteration variables. The iteration
  variables are initialized to the Op input, and the results of the
  tf.WhileRegion op are the final values of the iteration variables.

  This implies that the operand and result types for tf.WhileRegion should be
  the same. Note that the condition and body regions can implicitly capture
  loop invariant values directly. In canonical form, iteration variables that
  pass through the loop body unmodified are converted to implicitly captured
  references to their values outside the loop.
  }];

  let arguments = (ins
    Variadic<AnyTensor>:$input,

    ConfinedAttr<DefaultValuedAttr<I64Attr, "10">, [IntMinValue<1>]>:$parallel_iterations,

    // Used to map StatelessWhile and While op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless,

    // In TensorFlow, While has a special behavior where if `output_shapes`
    // attribute is not empty, those shapes are used in its shape function
    // as result shapes instead of propagating operand shapes as result shapes.
    // This allows for different result shapes from operand shapes. While these
    // shapes are imported and set as a part of the result type, there is no
    // indicator differentiating between having no output shapes compared to
    // having all unranked shapes. Thus this attribute is set to determine
    // which shape function behavior to use for this op, specifically
    // propagating operand shapes as result shapes when this attribute is not
    // set, or preserving result shapes as is when this attribute is set.
    UnitAttr:$shape_invariant
  );
  let results = (outs Variadic<AnyTensor>:$output);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let hasVerifier = 1;

  let hasCanonicalizer = 1;
}

// TensorList-creating op built on TF_TensorListInitOp, which supplies the
// `handle` result. Operand 0 is `element_shape`.
def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
  let summary = "List of the given size with empty elements.";

  let description = [{
element_shape: the shape of the future elements of the list
num_elements: the number of elements to reserve
handle: the output list
element_dtype: the desired type of elements in the list.
  }];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$num_elements
  );
}

// `dtype` and `shape` are not stored attributes here: both are derived from
// the first subtype of the `resource` result type (see `resource_subtype()`).
def TF_VarHandleOp : TF_Op<"VarHandleOp", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> {
  let summary = "Creates a handle to a Variable resource from its name.";

  let description = [{
container: the container this variable is placed in.
shared_name: the name by which this variable is referred to.
dtype and shape: attributes representing the data type and shape held in the
  variable.

Example:
    resource_variable_ops.var_handle_op(
          dtype=dtypes.int32, shape=[8, 16], container="foo", shared_name="bar")
  returns a handle for a variable with name "bar" in container "foo", and the
  variable holds a tensor of shape [8, 16] and dtype int32.
  }];

  let arguments = (ins
    DefaultValuedStrAttr<StrAttr, "">:$container,
    DefaultValuedStrAttr<StrAttr, "">:$shared_name
  );

  let results = (outs
    Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
  );

  DerivedTypeAttr dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(resource_subtype());">;
  DerivedAttr shape = DerivedAttr<
      "ShapedType",
      "return resource_subtype().cast<ShapedType>();",
      [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>;

  let extraClassDeclaration = [{
    // Subtype #0 of the resource type; read without checking that a subtype
    // is present.
    TensorType resource_subtype() { return resource_type().getSubtypes()[0]; }

    // Element type of the `resource` result, cast to TF::ResourceType.
    ResourceType resource_type() {
      return getElementTypeOrSelf(resource()).cast<TF::ResourceType>();
    }
  }];

  let hasVerifier = 1;
}

// Marked TF_NoConstantFold in addition to NoSideEffect. `T` is derived from
// operand index 0 (`input`).
def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect, TF_NoConstantFold]> {
  let summary = [{
An op which shards the input based on the given sharding attribute.
  }];

  let arguments = (ins
    TF_Tensor:$input,

    DefaultValuedStrAttr<StrAttr, "">:$sharding,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

// Takes no operands; `shapes` and `dtypes` are derived from result index 0
// (`outputs`).
def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
  let summary = "Fetches multiple values from infeed as an XLA tuple.";

  let arguments = (ins
    OptionalAttr<StrAttr>:$_XlaSharding,
    OptionalAttr<ArrayAttr>:$layouts
  );

  let results = (outs
    Variadic<TF_Tensor>:$outputs
  );

  TF_DerivedResultShapeListAttr shapes = TF_DerivedResultShapeListAttr<0>;
  TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}

// TODO(b/177675373): Make dtypes and shapes derived attributes,
// use more general solution.
// Here `dtypes` and `shapes` are stored attributes rather than derived ones.
// The op produces no results.
def TF_InfeedEnqueueTupleOp : TF_Op<"InfeedEnqueueTuple", []> {
  let summary = [{
Feeds multiple Tensor values into the computation as an XLA tuple.
  }];

  let arguments = (ins
    Arg<Variadic<TF_Tensor>, [{A list of tensors that will be provided using the infeed mechanism.}]>:$inputs,

    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$dtypes,
    TF_ShapeAttrArray:$shapes,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$layouts,
    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal
  );

  let results = (outs);
}

// This op is manually defined because the attribute name `template` (which is
// a keyword) is changed to `strtemplate`.
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
  let summary = "Formats a string template using a list of tensors.";

  let description = [{
Formats a string template using a list of tensors, pretty-printing tensor summaries.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    DefaultValuedStrAttr<StrAttr, "%s">:$strtemplate,
    DefaultValuedStrAttr<StrAttr, "%s">:$placeholder,
    DefaultValuedAttr<I64Attr, "3">:$summarize
  );

  let results = (outs
    TF_StrTensor:$output
  );

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}

//===----------------------------------------------------------------------===//
// tf.data ops
//===----------------------------------------------------------------------===//

// Carries two variadic operand groups (`initial_state`, `other_arguments`)
// under the SameVariadicOperandSize trait. The type and shape lists are stored
// attributes; `Targuments` may be empty (ArrayMinCount<0>) while the others
// need at least one entry.
def TF_ReduceDatasetOp : TF_Op<"ReduceDataset", [SameVariadicOperandSize]> {
  let summary = [{
    Reduces the input dataset to a singleton using a reduce function.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    Variadic<TF_Tensor>:$initial_state,
    Variadic<TF_Tensor>:$other_arguments,

    SymbolRefAttr:$f,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$Tstate,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<0>]>:$Targuments,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    ConfinedAttr<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism
  );

  let results = (outs
    Variadic<TF_Tensor>:$components
  );
}

// Manually defined to restrict result type to `I1Tensor`.
def TF_ToBoolOp : TF_Op<"ToBool", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
  let summary = "Converts a tensor to a scalar predicate.";

  let description = [{
Converts a tensor to a scalar predicate with the following rules:

- For 0D tensors, truthiness is determined by comparing against a "zero"
  value. For numerical types it is the obvious zero. For strings it is the
  empty string.

- For >0D tensors, truthiness is determined by looking at the number of
  elements. If has zero elements, then the result is false. Otherwise the
  result is true.

This matches the behavior of If and While for determining if a tensor counts
as true/false for a branch condition.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    I1Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // InferTypeOpInterface:
    // Inferred and actual result types are accepted when
    // ArraysAreCastCompatible(l, r) holds.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return ArraysAreCastCompatible(l, r);
    }
  }];
}

// Element-wise op on float tensors; operand and result share one type
// (SameOperandsAndResultType).
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i0e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.

This function is faster and numerically stabler than `bessel_i0(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

// Element-wise op on float tensors; operand and result share one type
// (SameOperandsAndResultType). The description names order 1, matching the
// `bessel_i1` function it is defined in terms of.
def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i1e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 1 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.

This function is faster and numerically stabler than `bessel_i1(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

// Both CallOpInterface and SymbolUserOpInterface are listed as plain traits
// (not through DeclareOpInterfaceMethods), so their methods, including
// verifySymbolUses, are declared by hand in extraClassDeclaration.
def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface, SymbolUserOpInterface]> {
  let summary = "Calls a function placed on a specified TPU device.";

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    TF_Int32Tensor:$device_ordinal,

    SymbolRefAttr:$f,
    DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the resolved callee function of this operation.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    // A null `table` falls back to the static SymbolTable lookup.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() { return ResolveFunc(nullptr); }

    // SymbolUserOpInterface verifier.
    LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable);
  }];
}

def TF_StatefulUniformFullIntOp : TF_Op<"StatefulUniformFullInt", []> {
  let summary = "Outputs random integers from a uniform distribution.";

  let description = [{
The generated values are uniform integers covering the whole range of `dtype`.
1021 }]; 1022 1023 let arguments = (ins 1024 Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource, 1025 TF_Int64Tensor:$algorithm, 1026 TF_I32OrI64Tensor:$shape 1027 ); 1028 1029 let results = (outs 1030 TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output 1031 ); 1032 1033 TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; 1034 TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; 1035} 1036 1037// TODO(lyandy): Investigate supported dtypes (`minval`, `maxval`, `output`) for 1038// `tf.StatefulUniformInt`. tf2xla kernels support i32, i64, ui32, and ui64 1039// while TensorFlow CPU/GPU kernels only support i32 and i64. 1040def TF_StatefulUniformIntOp : TF_Op<"StatefulUniformInt", []> { 1041 let summary = "Outputs random integers from a uniform distribution."; 1042 1043 let description = [{ 1044The generated values are uniform integers in the range `[minval, maxval)`. 1045The lower bound `minval` is included in the range, while the upper bound 1046`maxval` is excluded. 1047 1048The random integers are slightly biased unless `maxval - minval` is an exact 1049power of two. The bias is small for values of `maxval - minval` significantly 1050smaller than the range of the output (either `2^32` or `2^64`). 
1051 }]; 1052 1053 let arguments = (ins 1054 Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource, 1055 TF_Int64Tensor:$algorithm, 1056 TF_I32OrI64Tensor:$shape, 1057 TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval, 1058 TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval 1059 ); 1060 1061 let results = (outs 1062 TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output 1063 ); 1064 1065 TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>; 1066 TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<3>; 1067} 1068 1069def TF_CloseSummaryWriterOp : TF_Op<"CloseSummaryWriter", []> { 1070 let summary = "Flushes and closes the summary writer."; 1071 1072 let description = [{ 1073Also removes it from the resource manager. To reopen, use another 1074CreateSummaryFileWriter op. 1075 1076writer: A handle to the summary writer resource. 1077 }]; 1078 1079 let arguments = (ins 1080 Arg<TF_ResourceTensor, "", [TF_SummaryFree]>:$writer 1081 ); 1082 1083 let results = (outs); 1084} 1085 1086// TODO(b/168035831): Model db_uri read/write. 1087def TF_CreateSummaryDbWriterOp : TF_Op<"CreateSummaryDbWriter", []> { 1088 let summary = "Creates summary database writer accessible by given resource handle."; 1089 1090 let description = [{ 1091This can be used to write tensors from the execution graph directly 1092to a database. Only SQLite is supported right now. This function 1093will create the schema if it doesn't exist. Entries in the Users, 1094Experiments, and Runs tables will be created automatically if they 1095don't already exist. 1096 1097writer: Handle to SummaryWriter resource to overwrite. 1098db_uri: For example "file:/tmp/foo.sqlite". 1099experiment_name: Can't contain ASCII control characters or <>. Case 1100 sensitive. If empty, then the Run will not be associated with any 1101 Experiment. 1102run_name: Can't contain ASCII control characters or <>. Case sensitive. 
1103 If empty, then each Tag will not be associated with any Run. 1104user_name: Must be valid as both a DNS label and Linux username. If 1105 empty, then the Experiment will not be associated with any User. 1106 }]; 1107 1108 let arguments = (ins 1109 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1110 TF_StrTensor:$db_uri, 1111 TF_StrTensor:$experiment_name, 1112 TF_StrTensor:$run_name, 1113 TF_StrTensor:$user_name 1114 ); 1115 1116 let results = (outs); 1117} 1118 1119// TODO(b/168035831): Model logdir read/write. 1120def TF_CreateSummaryFileWriterOp : TF_Op<"CreateSummaryFileWriter", []> { 1121 let summary = "Creates a summary file writer accessible by the given resource handle."; 1122 1123 let description = [{ 1124writer: A handle to the summary writer resource 1125logdir: Directory where the event file will be written. 1126max_queue: Size of the queue of pending events and summaries. 1127flush_millis: How often, in milliseconds, to flush the pending events and 1128 summaries to disk. 1129filename_suffix: Every event file's name is suffixed with this suffix. 1130 }]; 1131 1132 let arguments = (ins 1133 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1134 TF_StrTensor:$logdir, 1135 TF_Int32Tensor:$max_queue, 1136 TF_Int32Tensor:$flush_millis, 1137 TF_StrTensor:$filename_suffix 1138 ); 1139 1140 let results = (outs); 1141} 1142 1143def TF_FlushSummaryWriterOp : TF_Op<"FlushSummaryWriter", []> { 1144 let summary = "Flushes the writer's unwritten events."; 1145 1146 let description = [{ 1147writer: A handle to the summary writer resource. 1148 }]; 1149 1150 let arguments = (ins 1151 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer 1152 ); 1153 1154 let results = (outs); 1155} 1156 1157def TF_ImportEventOp : TF_Op<"ImportEvent", []> { 1158 let summary = "Outputs a `tf.Event` protocol buffer."; 1159 1160 let description = [{ 1161When CreateSummaryDbWriter is being used, this op can be useful for 1162importing data from event logs. 
1163 1164writer: A handle to a summary writer. 1165event: A string containing a binary-encoded tf.Event proto. 1166 }]; 1167 1168 let arguments = (ins 1169 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1170 TF_StrTensor:$event 1171 ); 1172 1173 let results = (outs); 1174} 1175 1176def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> { 1177 let summary = "Returns a handle to be used to access a summary writer."; 1178 1179 let description = [{ 1180The summary writer is an in-graph resource which can be used by ops to write 1181summaries to event files. 1182 1183writer: the summary writer resource. Scalar handle. 1184 }]; 1185 1186 let arguments = (ins 1187 StrAttr:$shared_name, 1188 StrAttr:$container 1189 ); 1190 1191 let results = (outs 1192 Res<TF_ResourceTensor, "", [TF_SummaryAlloc]>:$writer 1193 ); 1194} 1195 1196def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> { 1197 let summary = "Writes a `Summary` protocol buffer with audio."; 1198 1199 let description = [{ 1200The summary has up to `max_outputs` summary values containing audio. The 1201audio is built from `tensor` which must be 3-D with shape `[batch_size, 1202frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are 1203assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. 1204 1205The `tag` argument is a scalar `Tensor` of type `string`. It is used to 1206build the `tag` of the summary values: 1207 1208* If `max_outputs` is 1, the summary value tag is '*tag*/audio'. 1209* If `max_outputs` is greater than 1, the summary value tags are 1210 generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. 1211 1212writer: A handle to a summary writer. 1213step: The step to write the summary for. 1214tag: Scalar. Used to build the `tag` attribute of the summary values. 1215tensor: 2-D of shape `[batch_size, frames]`. 1216sample_rate: The sample rate of the signal in hertz. 
1217max_outputs: Max number of batch elements to generate audio for. 1218 }]; 1219 1220 let arguments = (ins 1221 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1222 TF_Int64Tensor:$step, 1223 TF_StrTensor:$tag, 1224 TF_Float32Tensor:$tensor, 1225 TF_Float32Tensor:$sample_rate, 1226 1227 ConfinedAttr<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_outputs 1228 ); 1229 1230 let results = (outs); 1231} 1232 1233def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> { 1234 let summary = "Writes a `GraphDef` protocol buffer to a `SummaryWriter`."; 1235 1236 let description = [{ 1237writer: Handle of `SummaryWriter`. 1238step: The step to write the summary for. 1239tensor: A scalar string of the serialized tf.GraphDef proto. 1240 }]; 1241 1242 let arguments = (ins 1243 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1244 TF_Int64Tensor:$step, 1245 TF_StrTensor:$tensor 1246 ); 1247 1248 let results = (outs); 1249} 1250 1251def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> { 1252 let summary = "Writes a histogram summary."; 1253 1254 let description = [{ 1255The generated 1256[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) 1257has one summary value containing a histogram for `values`. 1258 1259This op reports an `InvalidArgument` error if any value is not finite. 1260 1261writer: A handle to a summary writer. 1262step: The step to write the summary for. 1263tag: Scalar. Tag to use for the `Summary.Value`. 1264values: Any shape. Values to use to build the histogram. 
1265 }]; 1266 1267 let arguments = (ins 1268 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1269 TF_Int64Tensor:$step, 1270 TF_StrTensor:$tag, 1271 TF_IntOrFpTensor:$values 1272 ); 1273 1274 let results = (outs); 1275 1276 TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; 1277} 1278 1279def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> { 1280 let summary = "Writes a `Summary` protocol buffer with images."; 1281 1282 let description = [{ 1283The summary has up to `max_images` summary values containing images. The 1284images are built from `tensor` which must be 4-D with shape `[batch_size, 1285height, width, channels]` and where `channels` can be: 1286 1287* 1: `tensor` is interpreted as Grayscale. 1288* 3: `tensor` is interpreted as RGB. 1289* 4: `tensor` is interpreted as RGBA. 1290 1291The images have the same number of channels as the input tensor. For float 1292input, the values are normalized one image at a time to fit in the range 1293`[0, 255]`. `uint8` values are unchanged. The op uses two different 1294normalization algorithms: 1295 1296* If the input values are all positive, they are rescaled so the largest one 1297 is 255. 1298 1299* If any input value is negative, the values are shifted so input value 0.0 1300 is at 127. They are then rescaled so that either the smallest value is 0, 1301 or the largest one is 255. 1302 1303The `tag` argument is a scalar `Tensor` of type `string`. It is used to 1304build the `tag` of the summary values: 1305 1306* If `max_images` is 1, the summary value tag is '*tag*/image'. 1307* If `max_images` is greater than 1, the summary value tags are 1308 generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. 1309 1310The `bad_color` argument is the color to use in the generated images for 1311non-finite input values. It is a `unit8` 1-D tensor of length `channels`. 1312Each element must be in the range `[0, 255]` (It represents the value of a 1313pixel in the output image). 
Non-finite values in the input tensor are 1314replaced by this tensor in the output image. The default value is the color 1315red. 1316 1317writer: A handle to a summary writer. 1318step: The step to write the summary for. 1319tag: Scalar. Used to build the `tag` attribute of the summary values. 1320tensor: 4-D of shape `[batch_size, height, width, channels]` where 1321 `channels` is 1, 3, or 4. 1322max_images: Max number of batch elements to generate images for. 1323bad_color: Color to use for pixels with non-finite values. 1324 }]; 1325 1326 let arguments = (ins 1327 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1328 TF_Int64Tensor:$step, 1329 TF_StrTensor:$tag, 1330 TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor, 1331 TF_Uint8Tensor:$bad_color, 1332 1333 ConfinedAttr<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_images 1334 ); 1335 1336 let results = (outs); 1337 1338 TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; 1339} 1340 1341def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> { 1342 let summary = "Writes a `Summary` protocol buffer with serialized string `Summary` protocol buffers."; 1343 1344 let description = [{ 1345writer: A handle to a summary writer. 1346step: The step to write the summary for. 1347tensor: A tensor holding one or more serialized `Summary` protobufs to write. 1348 }]; 1349 1350 let arguments = (ins 1351 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1352 TF_Int64Tensor:$step, 1353 TF_StrTensor:$tensor 1354 ); 1355 1356 let results = (outs); 1357} 1358 1359def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> { 1360 let summary = "Writes a `Summary` protocol buffer with scalar values."; 1361 1362 let description = [{ 1363The input `tag` and `value` must have the scalars. 1364 1365writer: A handle to a summary writer. 1366step: The step to write the summary for. 1367tag: Tag for the summary. 1368value: Value for the summary. 
1369 }]; 1370 1371 let arguments = (ins 1372 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1373 TF_Int64Tensor:$step, 1374 TF_StrTensor:$tag, 1375 TF_IntOrFpTensor:$value 1376 ); 1377 1378 let results = (outs); 1379 1380 TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; 1381} 1382 1383def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> { 1384 let summary = "Outputs a `Summary` protocol buffer with a tensor."; 1385 1386 let description = [{ 1387writer: A handle to a summary writer. 1388step: The step to write the summary for. 1389tensor: A tensor to serialize. 1390tag: The summary's tag. 1391summary_metadata: Serialized SummaryMetadata protocol buffer containing 1392 plugin-related metadata for this summary. 1393 }]; 1394 1395 let arguments = (ins 1396 Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer, 1397 TF_Int64Tensor:$step, 1398 TF_Tensor:$tensor, 1399 TF_StrTensor:$tag, 1400 TF_StrTensor:$summary_metadata 1401 ); 1402 1403 let results = (outs); 1404 1405 TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; 1406} 1407 1408def TF__TPUDeviceOrdinalPlaceholderOp : TF_Op<"_TPUDeviceOrdinalPlaceholder", [NoSideEffect]> { 1409 let summary = [{ 1410Placeholder device ordinal that represents device ordinal of a replicated op. 1411 }]; 1412 1413 let description = [{ 1414This op can be used when certain rewrite passes materialize ops that require a 1415device ordinal of a replicated op but replication logic has been abstracted away 1416using tf_device.replicate op. Subsequent rewrite passes must replace this op with 1417a constant output that represents the correct device ordinal of the replicated 1418operations inside a TPU host. 1419 }]; 1420 1421 let arguments = (ins); 1422 1423 let results = (outs 1424 TF_Int64Tensor:$device_ordinal 1425 ); 1426} 1427 1428def TF_TPUPartitionedInputOp : TF_Op<"TPUPartitionedInput", [NoSideEffect]> { 1429 let summary = [{ 1430An op that groups a list of partitioned inputs together. 
This op 1431 }]; 1432 1433 let arguments = (ins 1434 Variadic<TF_Tensor>:$inputs, 1435 1436 DefaultValuedAttr<I64Attr, "0">:$partition_dim, 1437 OptionalAttr<StrAttr>:$_XlaSharding 1438 ); 1439 1440 let results = (outs 1441 TF_Tensor:$output 1442 ); 1443 1444 TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; 1445 TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; 1446} 1447 1448def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> { 1449 let summary = [{ 1450An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned 1451 }]; 1452 1453 let description = [{ 1454outputs outside the XLA computation. 1455 }]; 1456 1457 let arguments = (ins 1458 TF_Tensor:$inputs, 1459 1460 DefaultValuedAttr<I64Attr, "0">:$partition_dim, 1461 OptionalAttr<StrAttr>:$_XlaSharding 1462 ); 1463 1464 let results = (outs 1465 Variadic<TF_Tensor>:$output 1466 ); 1467 1468 TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; 1469 TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>; 1470} 1471 1472// Declares symbol reference attribute `shape_inference_graph` to be optional 1473// unlike the TensorFlow definition. This is required to support ops that use 1474// empty string value for the attribute to signify missing. 1475def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", [TF_SendSideEffect, TF_RecvSideEffect, TF_XlaHostComputeSideEffect]> { 1476 let summary = [{ 1477A pseudo-op to represent host-side computation in an XLA program. 
1478 }]; 1479 1480 let arguments = (ins 1481 Arg<Variadic<TF_Tensor>, [{A list of tensors that will be sent to the host.}]>:$inputs, 1482 1483 StrArrayAttr:$ancestors, 1484 TF_ShapeAttrArray:$shapes, 1485 OptionalAttr<SymbolRefAttr>:$shape_inference_graph, 1486 StrAttr:$key, 1487 DefaultValuedStrAttr<StrAttr, "">:$send_key, 1488 DefaultValuedStrAttr<StrAttr, "">:$recv_key, 1489 DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns, 1490 DefaultValuedAttr<I64Attr, "0">:$tpu_core 1491 ); 1492 1493 let results = (outs 1494 Res<Variadic<TF_Tensor>, [{A list of tensors that will be returned to the device.}]>:$outputs 1495 ); 1496 1497 TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; 1498 TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; 1499} 1500 1501def TF_ConfigureAndInitializeGlobalTPUOp : TF_Op<"ConfigureAndInitializeGlobalTPU", []> { 1502 let summary = [{ 1503An op that initialize the TPU system in a multi-client set up. 1504 }]; 1505 1506 let description = [{ 1507Initializes global TPU system for mutli-client execution. 1508 1509This op does the work of both ConfigureDistributedTpuOp and 1510InitializeHostForDistributedTpuOp, and outputs the latter's result. 1511 }]; 1512 1513 let arguments = (ins); 1514 1515 let results = (outs 1516 Res<TF_Int32Tensor, [{A vector containing the global TPU id of each TPU on the host.}]>:$output 1517 ); 1518} 1519 1520def TF_ShutdownTPUSystemOp : TF_Op<"ShutdownTPUSystem", []> { 1521 let summary = [{ 1522An op that shuts down the TPU system. 1523 }]; 1524 1525 let arguments = (ins); 1526 let results = (outs 1527 TF_BoolTensor:$success 1528 ); 1529} 1530 1531// Internal op for testing value-based side-effects for non-resource values. 1532// TODO(mgester) We should have an extension of TF dialect only for testing so 1533// TF dialect is not polluted with test ops. 
def TF__InternalTestNonResourceValueSideEffects_ : TF_Op<"_InternalTestNonResourceValueSideEffects_", []> {
  let summary = "Internal op for testing only";

  let arguments = (ins
    Arg<TF_StrTensor,"", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$key
  );
  let results = (outs);
}

// Test-only op with no operands or results; it only carries the
// `TF_MustExecute` trait.
def TF__InternalTestMustExecuteTrait_ : TF_Op<"_InternalTestMustExecuteTrait_", [TF_MustExecute]> {
  let summary = "Internal op for testing only";

  let arguments = (ins);
  let results = (outs);
}

def TF_SetStaticDimensionBoundsOp : TF_Op<"SetStaticDimensionBounds", []> {
  let summary = "Op used to indicate to the compiler and runtime the static bounds of a tensor.";
  let description = [{
The information passed through this op can possibly be used by the compiler and
runtime to perform certain optimizations such as more efficient DMAs. The
bounds passed via this op should be considered advisory only, and depending on
the implementation, might do nothing and simply be an identity.

`input`: The tensor that has dynamic dimensions.
`static_shape`: The static shape of the tensor, corresponds to the maximum bounds of each dimension.
`output` is the input tensor with no changes done to it.

Example usage:

```
def tpu_call(args):
  def model_fn(args):
    # do something with dynamic tensor

  @function.Defun(capture_resource_var_by_value=False)
  def tpu_subgraph():
    return tf.tpu.rewrite(model_fn, args)

  return tf.raw_ops.TPUPartitionedCall(
      args=tpu_subgraph.captured_inputs,
      Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
      f=tpu_subgraph,
      device_ordinal=[0])

static_shape = tf.placeholder(tf.int32, shape=([3]), name='static_size')

w = tf.Variable(tf.constant([[1.0], [2.0], [3.0]]), name='w')

w_dyn = tf.SetStaticDimensionBounds(w, static_shape)
tpu_call([w_dyn])
```
  }];

  let arguments = (ins
    TF_Tensor:$input,
    TF_I32OrI64Tensor:$static_shape
  );

  // The verifier itself is implemented outside this file.
  let hasVerifier = 1;

  let results = (outs
    TF_Tensor:$output
  );
}

def TF_TPUCompileMlirAndExecuteOp : TF_Op<"TPUCompileMlirAndExecute", [AttrSizedOperandSegments]> {
  let summary = "Op that compiles a computation in MLIR into a TPU program, and loads and executes it on a TPU device.";

  let description = [{
For the internal use of the TPU compiler.

'static_shapes' are tensors specifying the maximum dimension sizes for the operands listed in 'operands_with_static_shape'.
'args' are inputs to the TPU computation.
'operands_with_static_shape' are the indices of the operands that have a maximal static shape specified.
'mlir_module' is a serialized MLIR module with a `main` function that contains
target computation.
'metadata' is a serialized TPUCompileMetadataProto describing the shapes and
types of the inputs to the computation, as well as a mapping onto the TPU pod
topology.
'producer_name' is a string describing the name of the framework that adds support for running this portion of the model on TPUs.
  }];

  // Two variadic operand groups (`args`, `static_shapes`); the op carries the
  // `AttrSizedOperandSegments` trait to tell them apart.
  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    Variadic<TF_Int64Tensor>:$static_shapes,
    OptionalAttr<I32ArrayAttr>:$operands_with_static_shape,
    DefaultValuedStrAttr<StrAttr, "">:$mlir_module,
    StrAttr:$metadata,
    StrAttr:$producer_name
  );

  let results = (outs
    TF_Tensor:$rendezvous_key_base,
    Variadic<TF_Tensor>:$results
  );

  // `Targs` is derived from the element types of the variadic operand `args`.
  TF_DerivedOperandTypeListAttr Targs = TF_DerivedOperandTypeListAttr<0>;
  // NOTE(review): result 0 is the non-variadic `rendezvous_key_base`, while the
  // variadic `results` is result 1 -- confirm that index 0 is intended here.
  TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}

#endif // TF_OPS