// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation definition file for TensorFlow.
//
// This file contains TensorFlow ops whose definitions are amended to fix
// issues or provide more information. In this file you have full control
// of the op definition; all changes will be retained with subsequent
// refreshes.
//
// This file includes another file, `tf_generated_ops.td`, which contains
// all ops whose definitions are generated from TensorFlow codebase.
// Changes made there are not respected.

#ifndef TF_OPS
#define TF_OPS

include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"

// Base class shared by the TensorList-creating ops (EmptyTensorList,
// TensorListReserve below). Subclasses provide the op-specific operands; the
// single result is always a variant tensor handle whose subtype carries the
// list's element type.
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
  let results = (outs
    TF_VariantTensor:$handle
  );

  // Derived from operand 0 (the subclasses' `element_shape` operand).
  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;

  let hasVerifier = 1;

  DerivedTypeAttr element_dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(element_type());">;

  let extraClassDeclaration = [{
    // Returns type of the TensorList element produced by this op.
    // NOTE(review): indexes subtype 0 unconditionally — presumably the
    // verifier guarantees the variant handle carries a subtype; confirm.
    TensorType element_type() { return handle_dtype().getSubtypes()[0]; }

    // Returns data type of the result handle. Returned type contains type of
    // the TensorList element as a subtype.
    VariantType handle_dtype() {
      return getElementTypeOrSelf(handle().getType()).cast<TF::VariantType>();
    }
  }];
}

def TF_CaseOp : TF_Op<"Case", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0](input);
        break;
      case 1:
        output = branches[1](input);
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1](input);
        break;
    }
    ```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,
    Variadic<TF_Tensor>:$input,

    // Symbol references to the branch functions; at least one is required.
    ConfinedAttr<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,

    // Used to map StatelessCase and Case op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived (not stored) attributes mirroring the TensorFlow op-registry
  // attrs of the same names; computed from the operand/result types.
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasCanonicalizer = 1;

  let hasVerifier = 1;


  let extraClassDeclaration = [{
    int num_branches() { return branches().size(); }

    // Gets function corresponding branch # `index`.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveBranchFunction(::mlir::SymbolTableCollection* table, int index) {
      auto flat_sym_ref = branches()[index].cast<FlatSymbolRefAttr>();
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, flat_sym_ref);
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, flat_sym_ref);
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp branch_function(int index) { return ResolveBranchFunction(nullptr, index); }

    // Resolve all branch functions.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    void ResolveBranchFunctions(::mlir::SymbolTableCollection* table,
        SmallVectorImpl<func::FuncOp> &functions) {
      functions.reserve(num_branches());
      for (int idx : llvm::seq<int>(0, num_branches()))
        functions.push_back(ResolveBranchFunction(table, idx));
    }
    // TODO(b/204997177): Deprecate and remove.
    void get_branch_functions(SmallVectorImpl<func::FuncOp> &functions) {
      return ResolveBranchFunctions(nullptr, functions);
    }
  }];
}

// Region-based counterpart of tf.Case: branches are inline regions rather
// than symbol references to functions.
def TF_CaseRegionOp : TF_Op<"CaseRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0](input);
        break;
      case 1:
        output = branches[1](input);
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1](input);
        break;
    }
    ```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,

    // Used to map StatelessCase and Case op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // One single-block region per case; blocks take no arguments
  // (NoRegionArguments) and are terminated by the implicit tf.Yield.
  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let hasVerifier = 1;

  let hasCanonicalizer = 1;

}

// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
// its type encoding the tensor's shape and data type.
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
  let summary = "Constant tensor op";

  let arguments = (ins
    ElementsAttr:$value
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;

  // Two convenience builders: from the value attribute alone, or with an
  // explicit result type.
  let builders = [
    OpBuilder<(ins "Attribute":$value)>,
    OpBuilder<(ins "Type":$type, "Attribute":$value)>,
  ];

  let hasFolder = 1;

  let extraClassDeclaration = [{
    // Relaxes InferTypeOpInterface's exact-match check: declared result types
    // only need to be broadcast compatible with the inferred ones.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return BroadcastCompatible(l, r);
    }
  }];
}

// Results, verifier and derived attrs come from TF_TensorListInitOp above.
def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
  let summary = "Creates and returns an empty tensor list.";

  let description = [{
All list elements must be tensors of dtype element_dtype and shape compatible
with element_shape.

handle: an empty tensor list.
element_dtype: the type of elements in the list.
element_shape: a shape compatible with that of elements in the list.
  }];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$max_num_elements
  );
}

def TF_IfOp : TF_Op<"If", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = "output = cond ? then_branch(input) : else_branch(input)";

  let description = [{
output = cond ? then_branch(input) : else_branch(input)

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
input: A list of input tensors.
then_branch: A function that takes 'inputs' and returns a list of
    tensors, whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of
    tensors.  whose types are the same as what then_branch returns.
  }];

  let arguments = (ins
    TF_Tensor:$cond,
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$then_branch,
    FlatSymbolRefAttr:$else_branch,

    // Used to map StatelessIf and If op defined in TensorFlow to a common op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attributes mirroring the TensorFlow op-registry attrs; computed
  // from operand/result types instead of being stored on the op.
  TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // Resolve the then branch function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveThenFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, then_branchAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
        *this, then_branchAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp then_function(::mlir::SymbolTableCollection* table = nullptr) {
      return ResolveThenFunction(table);
    }

    // Resolve the else branch function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveElseFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, else_branchAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
        *this, else_branchAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp else_function(::mlir::SymbolTableCollection* table = nullptr) {
      return ResolveElseFunction(table);
    }
  }];
}

// Terminator for the region-based control-flow ops in this file; ParentOneOf
// restricts where it may legally appear.
def TF_YieldOp : TF_Op<"Yield",
      [NoSideEffect, ReturnLike, Terminator,
       ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
  let summary = "Yield operation";

  let description = [{
    The "yield" operation represents a return operation within the conditional
    and body of structured control flow (e.g., if and while). The operation
    takes a variable number of operands and produces no results. The number and
    types of inputs must match the signature of the operation that contains the
    region.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);
}

// Region-based counterpart of tf.If: branches are inline regions rather than
// symbol references to functions.
def TF_IfRegionOp : TF_Op<"IfRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = "output = cond ? then_branch output : else_branch output";

  let description = [{
"output = cond ? then_branch output : else_branch output"

cond: A Tensor. If the tensor is a scalar of non-boolean type, the
    scalar is converted to a boolean according to the
    following rule: if the scalar is a numerical value, non-zero means
    True and zero means False; if the scalar is a string, non-empty
    means True and empty means False. If the tensor is not a scalar,
    being empty means False and being non-empty means True.
then_branch: A region that computes the outputs of the op if cond = true.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors is same as that of the else_branch
else_branch: A region that computes the outputs of the op if cond = false.
    It returns a list of tensors using tf.yield (as the terminator). The
    types of these returned tensors is same as that of the then_branch
  }];

  let arguments = (ins
    // Condition is constrained to a rank-0 (scalar) i1 tensor, unlike
    // functional tf.If which accepts any tensor.
    0DTensorOf<[I1]>:$cond,

    // Used to map StatelessIf and If op defined in TensorFlow to a common op.
    BoolAttr:$is_stateless,
    // Used to maintain function name when round-tripping
    // between functional and regional control flow.  This can be removed if
    // the runtime does not require globally unique then/else branch function names.
    OptionalAttr<StrAttr>:$_then_func_name,
    OptionalAttr<StrAttr>:$_else_func_name
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch);

  let hasRegionVerifier = 1;

  // Builder matching the generic (resultTypes, operands, attributes,
  // numRegions) form; asserts the caller asked for exactly two regions.
  let builders = [
    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
      "llvm::ArrayRef<::mlir::NamedAttribute>":$attributes,
      "unsigned":$numRegions),
    [{
      assert(numRegions == 2u && "mismatched number of regions");
      build($_builder, $_state, resultTypes, operands, attributes);
    }]>];

  let hasCanonicalizer = 1;
}

def TF_LegacyCallOp : TF_Op<"LegacyCall",
                            [CallOpInterface, NoSideEffect]> {
  let summary =
    "returns `f(inputs)`, where `f` is a function.";

  let description = [{
    The LegacyCall operation represents a direct call to a function that is
    within the same symbol scope as the call and is mapped to a GraphDef node
    with the function name as the op name. Unlike a PartitionedCall which
    represents asynchronously executing a function across multiple devices, a
    LegacyCall ignores specification for ops in the attached function and
    instead executes it on the device assigned to this op.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    DefaultValuedAttr<BoolAttr, "false">:$_disable_call_shape_inference
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // CallOpInterface implementation plus symbol resolution helpers.
  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the resolved callee function of this operation.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() {  return ResolveFunc(nullptr); }
  }];
}

// Manually defined because the op needs explicit operand/result segment-size
// attributes for its multiple variadic operand and result groups.
def TF_ParseExampleOp : TF_Op<"ParseExample",
                               [NoSideEffect,
                                AttrSizedResultSegments,
                                AttrSizedOperandSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    Variadic<TF_StrTensor>:$sparse_keys,
    Variadic<TF_StrTensor>:$dense_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    TF_ShapeAttrArray:$dense_shapes,
    // Required by AttrSizedResultSegments/AttrSizedOperandSegments to record
    // how the variadic results/operands are partitioned.
    DenseI32ArrayAttr:$result_segment_sizes,
    DenseI32ArrayAttr:$operand_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                           // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,  // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values    // len(Tdense)
  );

  // Derived attrs mirroring the TensorFlow op-registry attributes.
  TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
  TF_DerivedOperandSizeAttr Ndense = TF_DerivedOperandSizeAttr<3>;
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<4>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;

  let hasVerifier = 0;
}

// V2 of ParseExample: keys are dense (non-variadic) string tensors, and the
// op additionally produces ragged value/row-split outputs.
def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
                                [NoSideEffect,
                                 AttrSizedResultSegments]> {

  let summary =
    "Transforms a vector of tf.Example protos (as strings) into typed tensors.";

  let arguments = (ins
    TF_StrTensor:$serialized,
    TF_StrTensor:$names,
    TF_StrTensor:$sparse_keys,
    TF_StrTensor:$dense_keys,
    TF_StrTensor:$ragged_keys,
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,

    ConfinedAttr<I64Attr, [IntMinValue<0>]>:$num_sparse,
    TF_ShapeAttrArray:$dense_shapes,
    // Required by AttrSizedResultSegments to partition the variadic results.
    DenseI32ArrayAttr:$result_segment_sizes
  );

  let results = (outs
    Variadic<TF_Int64Tensor>:$sparse_indices,                           // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values,  // len(sparse_types)
    Variadic<TF_Int64Tensor>:$sparse_shapes,                            // len(sparse_types)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values,   // len(Tdense)
    Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$ragged_values,  // len(ragged_value_types)
                                                            //     = len(ragged_split_types)
    Variadic<TensorOf<[TF_Int32, TF_Int64]>>:$ragged_row_splits         // len(ragged_split_types)
                                                            //     = len(ragged_value_types)
  );

  // The Verify(ParseExampleV2Op) function validates that the lengths and types
  // of these attrs are compatible.
  TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<5>;
  TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;
  TF_DerivedResultTypeListAttr ragged_value_types =
    TF_DerivedResultTypeListAttr<4>;
  TF_DerivedResultTypeListAttr ragged_split_types =
    TF_DerivedResultTypeListAttr<5>;

  let hasVerifier = 1;
}

def TF_PlaceholderOp : TF_Op<"Placeholder", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
Inserts a placeholder for a tensor that will be always fed.
  }];

  // Intentionally takes no operands; the value is supplied externally (fed).
  let arguments = (ins
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}

def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
  let summary = "Placeholder op";

  let description = [{
    A placeholder op that passes through input when its output is not fed.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    TF_Tensor:$output
  );

  // Both dtype and shape are derived from the result type (index 0).
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
  DerivedAttr shape = TF_DerivedResultShapeAttr;
}

def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
                                         [CallOpInterface, SymbolUserOpInterface]> {
  let summary =
    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";

  let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    StrAttr:$config,
    StrAttr:$config_proto,
    StrAttr:$executor_type
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attrs mirroring the TensorFlow op-registry attributes.
  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  // CallOpInterface implementation plus symbol resolution helpers; the
  // SymbolUserOpInterface verifier is declared here and defined in C++.
  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the resolved callee function of this operation.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() {  return ResolveFunc(nullptr); }

    // SymbolUserOpInterface verifier.
    LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable);
  }];
}

def TF_WhileOp : TF_Op<"While", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
  let summary = [{
output = input; While (Cond(output)) { output = Body(output) }
  }];

  let description = [{
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor.  If the tensor is
    a scalar of non-boolean, the scalar is converted to a boolean
    according to the following rule: if the scalar is a numerical
    value, non-zero means True and zero means False; if the scalar is
    a string, non-empty means True and empty means False. If the
    tensor is not a scalar, non-emptiness means True and False
    otherwise.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$input,

    FlatSymbolRefAttr:$cond,
    FlatSymbolRefAttr:$body,
    ConfinedAttr<DefaultValuedAttr<I64Attr, "10">, [IntMinValue<1>]>:$parallel_iterations,

    // Used to map StatelessWhile and While op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless,

    // In TensorFlow, While has a special behavior where if `output_shapes`
    // attribute is not empty, those shapes are used in its shape function
    // as result shapes instead of propagating operand shapes as result shapes.
    // This allows for different result shapes from operand shapes. While these
    // shapes are imported and set as a part of the result type, there is no
    // indicator differentiating between having no output shapes compared to
    // having all unranked shapes. Thus this attribute is set to determine
    // which shape function behavior to use for this op, specifically
    // propagating operand shapes as result shapes when this attribute is not
    // set, or preserving result shapes as is when this attribute is set.
    UnitAttr:$shape_invariant
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // Derived attrs mirroring the TensorFlow op-registry attributes.
  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

  let extraClassDeclaration = [{
    // Get the condition function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveCondFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, condAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, condAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp cond_function() { return ResolveCondFunction(nullptr); }

    // Get the body function.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveBodyFunction(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, bodyAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, bodyAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp body_function() { return ResolveBodyFunction(nullptr); }
  }];
}

// Region-based counterpart of tf.While: cond and body are inline regions
// rather than symbol references to functions.
def TF_WhileRegionOp : TF_Op<"WhileRegion",
      [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
       SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "while operation";
  let description = [{
  The tf.WhileRegion op represents a while loop using 2 regions and a set of
  iteration variables. The iteration variables maintained by this Op have the
  same types as the inputs. The Op executes a while loop described by the
  following pseudo code:

  ```
     func WhileRegionOp(inputs) {
       iteration_vars = inputs;
       while (cond(iteration_vars)) {
           iteration_vars = body(iteration_vars);
       }
       return iteration_vars;
     }
  ```

  `cond` is the condition region and `body` is the body region. Both these
  regions accept the current value of the iteration variables as inputs. The
  condition region returns a tensor<i1> which, if false, will exit the loop.
  The body region computes new values of the iteration variables. The iteration
  variables are initialized to the Op input, and the results of the
  tf.WhileRegion op are the final values of the iteration variables.

  This implies that the operand and result types for tf.WhileRegion should be
  the same. Note that the condition and body regions can implicitly capture
  loop invariant values directly. In canonical form, iteration variables that
  pass through the loop body unmodified are converted to implicitly captured
  references to their values outside the loop.
  }];

  let arguments = (ins
    Variadic<AnyTensor>:$input,

    ConfinedAttr<DefaultValuedAttr<I64Attr, "10">, [IntMinValue<1>]>:$parallel_iterations,

    // Used to map StatelessWhile and While op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless,

    // In TensorFlow, While has a special behavior where if `output_shapes`
    // attribute is not empty, those shapes are used in its shape function
    // as result shapes instead of propagating operand shapes as result shapes.
    // This allows for different result shapes from operand shapes. While these
    // shapes are imported and set as a part of the result type, there is no
    // indicator differentiating between having no output shapes compared to
    // having all unranked shapes. Thus this attribute is set to determine
    // which shape function behavior to use for this op, specifically
    // propagating operand shapes as result shapes when this attribute is not
    // set, or preserving result shapes as is when this attribute is set.
    UnitAttr:$shape_invariant
  );
  let results = (outs Variadic<AnyTensor>:$output);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let hasVerifier = 1;

  let hasCanonicalizer = 1;
}

// Results, verifier and derived attrs come from TF_TensorListInitOp above.
def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
  let summary = "List of the given size with empty elements.";

  let description = [{
element_shape: the shape of the future elements of the list
num_elements: the number of elements to reserve
handle: the output list
element_dtype: the desired type of elements in the list.
  }];

  let arguments = (ins
    TF_I32OrI64Tensor:$element_shape,
    TF_Int32Tensor:$num_elements
  );
}

def TF_VarHandleOp : TF_Op<"VarHandleOp", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> {
  let summary = "Creates a handle to a Variable resource from its name.";

  let description = [{
container: the container this variable is placed in.
shared_name: the name by which this variable is referred to.
dtype and shape: attributes representing the data type and shape held in the
  variable.

Example:
    resource_variable_ops.var_handle_op(
          dtype=dtypes.int32, shape=[8, 16], container="foo", shared_name="bar")
  returns a handle for a variable with name "bar" in container "foo", and the
  variable holds a tensor of shape [8, 16] and dtype int32.
  }];

  let arguments = (ins
    DefaultValuedStrAttr<StrAttr, "">:$container,
    DefaultValuedStrAttr<StrAttr, "">:$shared_name
  );

  // The result is tagged with the TF_VariableAlloc resource effect.
  let results = (outs
    Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
  );

  // dtype/shape are derived from the resource handle's subtype rather than
  // stored as attributes.
  DerivedTypeAttr dtype = DerivedTypeAttr<
      "return getElementTypeOrSelf(resource_subtype());">;
  DerivedAttr shape = DerivedAttr<
      "ShapedType",
      "return resource_subtype().cast<ShapedType>();",
      [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>;

  let extraClassDeclaration = [{
    // NOTE(review): indexes subtype 0 unconditionally — presumably the
    // verifier guarantees the resource type carries a subtype; confirm.
    TensorType resource_subtype() { return resource_type().getSubtypes()[0]; }

    ResourceType resource_type() {
      return getElementTypeOrSelf(resource()).cast<TF::ResourceType>();
    }
  }];

  let hasVerifier = 1;
}

// TF_NoConstantFold keeps the op from being folded away so the sharding
// annotation survives until it is consumed.
def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect, TF_NoConstantFold]> {
  let summary = [{
An op which shards the input based on the given sharding attribute.
  }];

  let arguments = (ins
    TF_Tensor:$input,

    DefaultValuedStrAttr<StrAttr, "">:$sharding,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

// Empty trait list: this op is stateful (infeed dequeue), so it deliberately
// does not carry NoSideEffect.
def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
  let summary = "Fetches multiple values from infeed as an XLA tuple.";

  let arguments = (ins
    OptionalAttr<StrAttr>:$_XlaSharding,
    OptionalAttr<ArrayAttr>:$layouts
  );

  let results = (outs
    Variadic<TF_Tensor>:$outputs
  );

  // shapes/dtypes are derived from the result types.
  TF_DerivedResultShapeListAttr shapes = TF_DerivedResultShapeListAttr<0>;
  TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}

// TODO(b/177675373): Make dtypes and shapes derived attributes,
// use more general solution.
def TF_InfeedEnqueueTupleOp : TF_Op<"InfeedEnqueueTuple", []> {
  let summary = [{
Feeds multiple Tensor values into the computation as an XLA tuple.
  }];

  let arguments = (ins
    Arg<Variadic<TF_Tensor>, [{A list of tensors that will be provided using the infeed mechanism.}]>:$inputs,

    // Unlike InfeedDequeueTuple above, dtypes/shapes are stored attributes
    // here (see the TODO).
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$dtypes,
    TF_ShapeAttrArray:$shapes,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$layouts,
    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal
  );

  let results = (outs);
}

// This op is manually defined because the attribute name `template` (which is
// a keyword) is changed to `strtemplate`.
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
  let summary = "Formats a string template using a list of tensors.";

  let description = [{
Formats a string template using a list of tensors, pretty-printing tensor summaries.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    // `strtemplate` corresponds to the TensorFlow `template` attribute (see
    // comment above). Each occurrence of `placeholder` in the template is
    // substituted; `summarize` bounds how many entries of each dimension are
    // printed.
    DefaultValuedStrAttr<StrAttr, "%s">:$strtemplate,
    DefaultValuedStrAttr<StrAttr, "%s">:$placeholder,
    DefaultValuedAttr<I64Attr, "3">:$summarize
  );

  let results = (outs
    TF_StrTensor:$output
  );

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}
866
867//===----------------------------------------------------------------------===//
868// tf.data ops
869//===----------------------------------------------------------------------===//
870
// NOTE(review): SameVariadicOperandSize requires the two variadic groups
// (`initial_state`, `other_arguments`) to have equal lengths; it is used here
// to disambiguate operand segments. Confirm this matches importer behavior.
def TF_ReduceDatasetOp : TF_Op<"ReduceDataset", [SameVariadicOperandSize]> {
  let summary = [{
    Reduces the input dataset to a singleton using a reduce function.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    Variadic<TF_Tensor>:$initial_state,
    Variadic<TF_Tensor>:$other_arguments,

    // `f` is the reduce function applied across the dataset; `Tstate` /
    // `Targuments` give the types of the corresponding variadic operands.
    SymbolRefAttr:$f,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$Tstate,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<0>]>:$Targuments,
    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    ConfinedAttr<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism
  );

  let results = (outs
    Variadic<TF_Tensor>:$components
  );
}
893
// Manually defined to restrict result type to `I1Tensor`.
def TF_ToBoolOp : TF_Op<"ToBool", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
  let summary = "Converts a tensor to a scalar predicate.";

  let description = [{
Converts a tensor to a scalar predicate with the following rules:

- For 0D tensors, truthiness is determined by comparing against a "zero"
  value. For numerical types it is the obvious zero. For strings it is the
  empty string.

- For >0D tensors, truthiness is determined by looking at the number of
  elements. If has zero elements, then the result is false. Otherwise the
  result is true.

This matches the behavior of If and While for determining if a tensor counts
as true/false for a branch condition.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    // Restricted to i1 (see comment above the def).
    I1Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasCanonicalizer = 1;

  let extraClassDeclaration = [{
    // InferTypeOpInterface:
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return ArraysAreCastCompatible(l, r);
    }
  }];
}
932
// Element-wise scaled Bessel function; result type matches the float input.
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i0e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.

This function is faster and numerically stabler than `bessel_i0(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
953
// Element-wise scaled Bessel function; result type matches the float input.
// Manually amended: the generated description said "order 0", but this op
// computes the scaled modified Bessel function of order 1, as its own formula
// (`bessel_i1`) shows.
def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes the Bessel i1e function of `x` element-wise.";

  let description = [{
Exponentially scaled modified Bessel function of order 1 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.

This function is faster and numerically stabler than `bessel_i1(x)`.
  }];

  let arguments = (ins
    TF_FloatTensor:$x
  );

  let results = (outs
    TF_FloatTensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
974
// Call-like op: implements CallOpInterface so generic call analyses work, and
// SymbolUserOpInterface so the callee symbol `f` is verified to resolve.
def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface, SymbolUserOpInterface]> {
  let summary = "Calls a function placed on a specified TPU device.";

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    // Scalar-ish i32 tensor selecting the TPU device; it is an operand, not
    // part of the function's argument list (see getArgOperands below).
    TF_Int32Tensor:$device_ordinal,

    SymbolRefAttr:$f,
    DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() { return fAttr(); }

    // Returns the resolved callee function of this operation.
    // Prefer passing in SymbolTableCollection to reduce lookup costs by
    // enabling reusing cached symbol table lookup.
    func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) {
      if (table)
        return table->lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
      return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, fAttr());
    }
    // TODO(b/204997177): Deprecate and remove.
    func::FuncOp func() {  return ResolveFunc(nullptr); }

    // SymbolUserOpInterface verifier.
    LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable);
  }];
}
1015
// Stateful RNG op: the resource operand models both a read and a write of the
// generator state, so this op is ordered against other users of the resource.
def TF_StatefulUniformFullIntOp : TF_Op<"StatefulUniformFullInt", []> {
  let summary = "Outputs random integers from a uniform distribution.";

  let description = [{
The generated values are uniform integers covering the whole range of `dtype`.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape
  );

  let results = (outs
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
  );

  // `shape_dtype` comes from the `shape` operand; `dtype` from the result.
  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
1036
// TODO(lyandy): Investigate supported dtypes (`minval`, `maxval`, `output`) for
// `tf.StatefulUniformInt`. tf2xla kernels support i32, i64, ui32, and ui64
// while TensorFlow CPU/GPU kernels only support i32 and i64.
def TF_StatefulUniformIntOp : TF_Op<"StatefulUniformInt", []> {
  let summary = "Outputs random integers from a uniform distribution.";

  let description = [{
The generated values are uniform integers in the range `[minval, maxval)`.
The lower bound `minval` is included in the range, while the upper bound
`maxval` is excluded.

The random integers are slightly biased unless `maxval - minval` is an exact
power of two.  The bias is small for values of `maxval - minval` significantly
smaller than the range of the output (either `2^32` or `2^64`).
  }];

  let arguments = (ins
    // Generator state resource: modeled as read+write so uses are ordered.
    Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
    TF_Int64Tensor:$algorithm,
    TF_I32OrI64Tensor:$shape,
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval,
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval
  );

  let results = (outs
    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
  );

  // `shape_dtype` from the `shape` operand; `dtype` from `minval` (index 3).
  TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<3>;
}
1068
// The TF_SummaryFree effect on `writer` models deallocation of the summary
// writer resource, ordering this op after all writes to the same writer.
def TF_CloseSummaryWriterOp : TF_Op<"CloseSummaryWriter", []> {
  let summary = "Flushes and closes the summary writer.";

  let description = [{
Also removes it from the resource manager. To reopen, use another
CreateSummaryFileWriter op.

writer: A handle to the summary writer resource.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryFree]>:$writer
  );

  let results = (outs);
}
1085
// TODO(b/168035831): Model db_uri read/write.
def TF_CreateSummaryDbWriterOp : TF_Op<"CreateSummaryDbWriter", []> {
  let summary = "Creates summary database writer accessible by given resource handle.";

  let description = [{
This can be used to write tensors from the execution graph directly
to a database. Only SQLite is supported right now. This function
will create the schema if it doesn't exist. Entries in the Users,
Experiments, and Runs tables will be created automatically if they
don't already exist.

writer: Handle to SummaryWriter resource to overwrite.
db_uri: For example "file:/tmp/foo.sqlite".
experiment_name: Can't contain ASCII control characters or <>. Case
  sensitive. If empty, then the Run will not be associated with any
  Experiment.
run_name: Can't contain ASCII control characters or <>. Case sensitive.
  If empty, then each Tag will not be associated with any Run.
user_name: Must be valid as both a DNS label and Linux username. If
  empty, then the Experiment will not be associated with any User.
  }];

  let arguments = (ins
    // Writing into the pre-existing writer resource; see TODO above for the
    // unmodeled db_uri side effect.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$db_uri,
    TF_StrTensor:$experiment_name,
    TF_StrTensor:$run_name,
    TF_StrTensor:$user_name
  );

  let results = (outs);
}
1118
// TODO(b/168035831): Model logdir read/write.
def TF_CreateSummaryFileWriterOp : TF_Op<"CreateSummaryFileWriter", []> {
  let summary = "Creates a summary file writer accessible by the given resource handle.";

  let description = [{
writer: A handle to the summary writer resource
logdir: Directory where the event file will be written.
max_queue: Size of the queue of pending events and summaries.
flush_millis: How often, in milliseconds, to flush the pending events and
  summaries to disk.
filename_suffix: Every event file's name is suffixed with this suffix.
  }];

  let arguments = (ins
    // Writing into the pre-existing writer resource; see TODO above for the
    // unmodeled logdir side effect.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$logdir,
    TF_Int32Tensor:$max_queue,
    TF_Int32Tensor:$flush_millis,
    TF_StrTensor:$filename_suffix
  );

  let results = (outs);
}
1142
// Modeled as a write on the writer resource so it stays ordered relative to
// the Write*Summary ops targeting the same writer.
def TF_FlushSummaryWriterOp : TF_Op<"FlushSummaryWriter", []> {
  let summary = "Flushes the writer's unwritten events.";

  let description = [{
writer: A handle to the summary writer resource.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer
  );

  let results = (outs);
}
1156
def TF_ImportEventOp : TF_Op<"ImportEvent", []> {
  let summary = "Outputs a `tf.Event` protocol buffer.";

  let description = [{
When CreateSummaryDbWriter is being used, this op can be useful for
importing data from event logs.

writer: A handle to a summary writer.
event: A string containing a binary-encoded tf.Event proto.
  }];

  let arguments = (ins
    // Write effect orders this against other ops using the same writer.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_StrTensor:$event
  );

  let results = (outs);
}
1175
// Declares TF_ResourceHandleAllocatorInterface so passes can reason about
// (and unique) the resource handle this op allocates.
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [DeclareOpInterfaceMethods<TF_ResourceHandleAllocatorInterface>]> {
  let summary = "Returns a handle to be used to access a summary writer.";

  let description = [{
The summary writer is an in-graph resource which can be used by ops to write
summaries to event files.

writer: the summary writer resource. Scalar handle.
  }];

  let arguments = (ins
    // Resource-manager keys identifying the shared writer instance.
    StrAttr:$shared_name,
    StrAttr:$container
  );

  let results = (outs
    // TF_SummaryAlloc marks this result as allocating the writer resource.
    Res<TF_ResourceTensor, "", [TF_SummaryAlloc]>:$writer
  );
}
1195
def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with audio.";

  let description = [{
The summary has up to `max_outputs` summary values containing audio. The
audio is built from `tensor` which must be 3-D with shape `[batch_size,
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.

The `tag` argument is a scalar `Tensor` of type `string`.  It is used to
build the `tag` of the summary values:

*  If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
*  If `max_outputs` is greater than 1, the summary value tags are
   generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar. Used to build the `tag` attribute of the summary values.
tensor: 2-D of shape `[batch_size, frames]`.
sample_rate: The sample rate of the signal in hertz.
max_outputs: Max number of batch elements to generate audio for.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_Float32Tensor:$tensor,
    TF_Float32Tensor:$sample_rate,

    // Attribute (not operand); constrained to be >= 1.
    ConfinedAttr<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_outputs
  );

  let results = (outs);
}
1232
def TF_WriteGraphSummaryOp : TF_Op<"WriteGraphSummary", []> {
  let summary = "Writes a `GraphDef` protocol buffer to a `SummaryWriter`.";

  let description = [{
writer: Handle of `SummaryWriter`.
step: The step to write the summary for.
tensor: A scalar string of the serialized tf.GraphDef proto.
  }];

  let arguments = (ins
    // Write effect orders this against other ops using the same writer.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tensor
  );

  let results = (outs);
}
1250
def TF_WriteHistogramSummaryOp : TF_Op<"WriteHistogramSummary", []> {
  let summary = "Writes a histogram summary.";

  let description = [{
The generated
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
has one summary value containing a histogram for `values`.

This op reports an `InvalidArgument` error if any value is not finite.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Scalar.  Tag to use for the `Summary.Value`.
values: Any shape. Values to use to build the histogram.
  }];

  let arguments = (ins
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_IntOrFpTensor:$values
  );

  let results = (outs);

  // `T` is derived from the `values` operand (index 3).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
1278
// Manually amended: fixed `unit8` -> `uint8` typo in the description below.
def TF_WriteImageSummaryOp : TF_Op<"WriteImageSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with images.";

  let description = [{
The summary has up to `max_images` summary values containing images. The
images are built from `tensor` which must be 4-D with shape `[batch_size,
height, width, channels]` and where `channels` can be:

*  1: `tensor` is interpreted as Grayscale.
*  3: `tensor` is interpreted as RGB.
*  4: `tensor` is interpreted as RGBA.

The images have the same number of channels as the input tensor. For float
input, the values are normalized one image at a time to fit in the range
`[0, 255]`.  `uint8` values are unchanged.  The op uses two different
normalization algorithms:

*  If the input values are all positive, they are rescaled so the largest one
   is 255.

*  If any input value is negative, the values are shifted so input value 0.0
   is at 127.  They are then rescaled so that either the smallest value is 0,
   or the largest one is 255.

The `tag` argument is a scalar `Tensor` of type `string`.  It is used to
build the `tag` of the summary values:

*  If `max_images` is 1, the summary value tag is '*tag*/image'.
*  If `max_images` is greater than 1, the summary value tags are
   generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.

The `bad_color` argument is the color to use in the generated images for
non-finite input values.  It is a `uint8` 1-D tensor of length `channels`.
Each element must be in the range `[0, 255]` (It represents the value of a
pixel in the output image).  Non-finite values in the input tensor are
replaced by this tensor in the output image.  The default value is the color
red.
  }];

  let arguments = (ins
    // Write effect orders this against other ops using the same writer.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor,
    TF_Uint8Tensor:$bad_color,

    // Attribute (not operand); constrained to be >= 1.
    ConfinedAttr<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_images
  );

  let results = (outs);

  // `T` is derived from the `tensor` operand (index 3).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
1340
def TF_WriteRawProtoSummaryOp : TF_Op<"WriteRawProtoSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with serialized string `Summary` protocol buffers.";

  let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor holding one or more serialized `Summary` protobufs to write.
  }];

  let arguments = (ins
    // Write effect orders this against other ops using the same writer.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tensor
  );

  let results = (outs);
}
1358
// Manually amended: reworded the ungrammatical "must have the scalars" in the
// description.
def TF_WriteScalarSummaryOp : TF_Op<"WriteScalarSummary", []> {
  let summary = "Writes a `Summary` protocol buffer with scalar values.";

  let description = [{
The input `tag` and `value` must be scalars.

writer: A handle to a summary writer.
step: The step to write the summary for.
tag: Tag for the summary.
value: Value for the summary.
  }];

  let arguments = (ins
    // Write effect orders this against other ops using the same writer.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_StrTensor:$tag,
    TF_IntOrFpTensor:$value
  );

  let results = (outs);

  // `T` is derived from the `value` operand (index 3).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
1382
def TF_WriteSummaryOp : TF_Op<"WriteSummary", []> {
  let summary = "Outputs a `Summary` protocol buffer with a tensor.";

  let description = [{
writer: A handle to a summary writer.
step: The step to write the summary for.
tensor: A tensor to serialize.
tag: The summary's tag.
summary_metadata: Serialized SummaryMetadata protocol buffer containing
 plugin-related metadata for this summary.
  }];

  let arguments = (ins
    // Write effect orders this against other ops using the same writer.
    Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
    TF_Int64Tensor:$step,
    TF_Tensor:$tensor,
    TF_StrTensor:$tag,
    TF_StrTensor:$summary_metadata
  );

  let results = (outs);

  // `T` is derived from the `tensor` operand (index 2).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
1407
// Internal placeholder op (leading underscore): produced and later replaced by
// rewrite passes, never expected to survive to execution.
def TF__TPUDeviceOrdinalPlaceholderOp : TF_Op<"_TPUDeviceOrdinalPlaceholder", [NoSideEffect]> {
  let summary = [{
Placeholder device ordinal that represents device ordinal of a replicated op.
  }];

  let description = [{
This op can be used when certain rewrite passes materialize ops that require a
device ordinal of a replicated op but replication logic has been abstracted away
using tf_device.replicate op. Subsequent rewrite passes must replace this op with
a constant output that represents the correct device ordinal of the replicated
operations inside a TPU host.
  }];

  let arguments = (ins);

  let results = (outs
    TF_Int64Tensor:$device_ordinal
  );
}
1427
// NOTE(review): the summary text is truncated mid-sentence ("This op"); it
// mirrors the upstream TensorFlow op registration, so it is left as-is here.
def TF_TPUPartitionedInputOp : TF_Op<"TPUPartitionedInput", [NoSideEffect]> {
  let summary = [{
An op that groups a list of partitioned inputs together. This op
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    // Dimension along which the inputs are partitioned; `_XlaSharding`
    // optionally carries the sharding annotation for the merged value.
    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    TF_Tensor:$output
  );

  // `T` and the operand count `N` are derived from the variadic operand.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
1447
// Inverse of TPUPartitionedInput: splits one XLA-sharded tensor into its
// per-partition pieces. The summary/description sentence is split across the
// two fields, mirroring the upstream op registration.
def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [NoSideEffect]> {
  let summary = [{
An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned
  }];

  let description = [{
outputs outside the XLA computation.
  }];

  let arguments = (ins
    TF_Tensor:$inputs,

    // Dimension along which the input is split; `_XlaSharding` optionally
    // carries the sharding annotation.
    DefaultValuedAttr<I64Attr, "0">:$partition_dim,
    OptionalAttr<StrAttr>:$_XlaSharding
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  // `T` from the input; `num_splits` derived from the result count.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>;
}
1471
// Declares symbol reference attribute `shape_inference_graph` to be optional
// unlike the TensorFlow definition. This is required to support ops that use
// empty string value for the attribute to signify missing.
def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", [TF_SendSideEffect, TF_RecvSideEffect, TF_XlaHostComputeSideEffect]> {
  let summary = [{
A pseudo-op to represent host-side computation in an XLA program.
  }];

  let arguments = (ins
    Arg<Variadic<TF_Tensor>, [{A list of tensors that will be sent to the host.}]>:$inputs,

    // `send_key`/`recv_key` name the device-to-host / host-to-device
    // channels; empty string means unset. `shape_inference_graph` is optional
    // here (see comment above the def).
    StrArrayAttr:$ancestors,
    TF_ShapeAttrArray:$shapes,
    OptionalAttr<SymbolRefAttr>:$shape_inference_graph,
    StrAttr:$key,
    DefaultValuedStrAttr<StrAttr, "">:$send_key,
    DefaultValuedStrAttr<StrAttr, "">:$recv_key,
    DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns,
    DefaultValuedAttr<I64Attr, "0">:$tpu_core
  );

  let results = (outs
    Res<Variadic<TF_Tensor>, [{A list of tensors that will be returned to the device.}]>:$outputs
  );

  TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
1500
// Manually amended: fixed grammar ("initialize" -> "initializes") and the
// "mutli-client" typo in the documentation strings.
def TF_ConfigureAndInitializeGlobalTPUOp : TF_Op<"ConfigureAndInitializeGlobalTPU", []> {
  let summary = [{
An op that initializes the TPU system in a multi-client set up.
  }];

  let description = [{
Initializes global TPU system for multi-client execution.

This op does the work of both ConfigureDistributedTpuOp and
InitializeHostForDistributedTpuOp, and outputs the latter's result.
  }];

  let arguments = (ins);

  let results = (outs
    Res<TF_Int32Tensor, [{A vector containing the global TPU id of each TPU on the host.}]>:$output
  );
}
1519
def TF_ShutdownTPUSystemOp : TF_Op<"ShutdownTPUSystem", []> {
  let summary = [{
An op that shuts down the TPU system.
  }];

  let arguments = (ins);
  let results = (outs
    // Boolean tensor reporting whether shutdown succeeded.
    TF_BoolTensor:$success
  );
}
1530
// Internal op for testing value-based side-effects for non-resource values.
// TODO(mgester) We should have an extension of TF dialect only for testing so
// TF dialect is not polluted with test ops.
def TF__InternalTestNonResourceValueSideEffects_ : TF_Op<"_InternalTestNonResourceValueSideEffects_", []> {
  let summary = "Internal op for testing only";

  let arguments = (ins
    // String (non-resource) operand deliberately tagged with resource-style
    // read/write effects to exercise value-based side-effect analysis.
    Arg<TF_StrTensor,"", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$key
  );
  let results = (outs);
}
1542
// Internal test-only op: no operands/results; exists solely to carry the
// TF_MustExecute trait for side-effect analysis tests.
def TF__InternalTestMustExecuteTrait_ : TF_Op<"_InternalTestMustExecuteTrait_", [TF_MustExecute]> {
  let summary = "Internal op for testing only";

  let arguments = (ins);
  let results = (outs);
}
1549
// Manually amended: the usage example previously had a stray `]`, called the
// op `SetDynamicDimensionBounds`, and referenced an undefined `static_size`
// variable; fixed to use this op's name and the `static_shape` defined above.
def TF_SetStaticDimensionBoundsOp : TF_Op<"SetStaticDimensionBounds", []> {
  let summary = "Op used to indicate to the compiler and runtime the static bounds of a tensor.";
  let description = [{
The information passed through this op can possibly be used by the compiler and
runtime to perform certain optimizations such as more efficient DMAs. The
bounds passed via this op should be considered advisory only, and depending on
the implementation, might do nothing and simply be an identity

`input`: The tensor that has dynamic dimensions.
`static_shape`: The static shape of the tensor, corresponds to the maximum bounds of each dimension.
`output` is the input tensor with no changes done to it.

Example usage:

def tpu_call(args):
  def model_fn(args):
    # do something with dynamic tensor

  @function.Defun(capture_resource_var_by_value=False)
  def tpu_subgraph():
      return tf.tpu.rewrite(model_fn, args)

  return tf.raw_ops.TPUPartitionedCall(
      args=tpu_subgraph.captured_inputs,
      Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
      f=tpu_subgraph,
      device_ordinal=[0])

static_shape = tf.placeholder(tf.int32, shape=([3]), name='static_size')

w = tf.Variable(tf.constant([[1.0], [2.0], [3.0]]), name='w')

w_dyn = tf.SetStaticDimensionBounds(w, static_shape)
tpu_call([w_dyn])
}];
  let arguments = (ins
    TF_Tensor:$input,
    TF_I32OrI64Tensor:$static_shape
  );

  let hasVerifier = 1;

  let results = (outs
    TF_Tensor:$output
  );
}
1596
// AttrSizedOperandSegments is required because the op has two variadic operand
// groups (`args`, `static_shapes`) whose sizes must be stored explicitly.
def TF_TPUCompileMlirAndExecuteOp : TF_Op<"TPUCompileMlirAndExecute", [AttrSizedOperandSegments]> {
  let summary = "Op that compiles a computation in MLIR into a TPU program, and loads and executes it on a TPU device.";

  let description = [{
For the internal use of the TPU compiler.

'static_shapes' are tensors specifying the maximum dimension sizes for the tensors specified in `dynamic_operands`.
'args' are inputs to the TPU computation.
'operands_with_static_shape' are the indices of the operands that have a maximal static shape specified.
'mlir_module' is a serialized MLIR module with a `main` function that contains
target computation.
'metadata' is a serialized TPUCompileMetadataProto describing the shapes and
types of the inputs to the computation, as well as a mapping onto the TPU pod
topology.
'producer_name' is a string describing the name of the framework that add support for running this portion of the model on TPUs.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    Variadic<TF_Int64Tensor>:$static_shapes,
    OptionalAttr<I32ArrayAttr>:$operands_with_static_shape,
    DefaultValuedStrAttr<StrAttr, "">:$mlir_module,
    StrAttr:$metadata,
    StrAttr:$producer_name
  );

  let results = (outs
    TF_Tensor:$rendezvous_key_base,
    Variadic<TF_Tensor>:$results
  );

  TF_DerivedOperandTypeListAttr Targs = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}
1631
1632#endif // TF_OPS
1633