xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// This is the base operation definition file for TensorFlow.
17//
18// This file includes the definition for the TensorFlow dialect, base TensorFlow
19// op, and various commonly used TensorFlow traits, types, attributes, and
20// builders.
21
22#ifndef TF_OP_BASE
23#define TF_OP_BASE
24
25include "mlir/IR/OpBase.td"
26include "mlir/Interfaces/SideEffectInterfaces.td"
27include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
28
29//===----------------------------------------------------------------------===//
30// TensorFlow dialect definitions
31//===----------------------------------------------------------------------===//
32
// The `tf` dialect; generated C++ lives in ::mlir::TF.
def TensorFlowDialect : Dialect {
  let name = "tf";
  let cppNamespace = "::mlir::TF";

  // Accessor naming policy for ops in this dialect: the raw (unprefixed)
  // scheme.
  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];
}
53
54//===----------------------------------------------------------------------===//
55// TensorFlow traits
56//===----------------------------------------------------------------------===//
57
// All results share a single type, and every operand either has that same
// type or is the reference type corresponding to it.
def TF_OperandsSameAsResultsTypeOrRef :
    NativeOpTrait<"TF::OperandsSameAsResultsTypeOrRef">;

// Operands and results share an element type (or the type itself for
// scalars) once reference types are resolved to their TensorFlow or standard
// counterparts.
def TF_SameOperandsAndResultElementTypeResolveRef :
    NativeOpTrait<"TF::SameOperandsAndResultElementTypeResolveRef">;

// Operands and results share a type once reference types are resolved to
// their TensorFlow or standard counterparts.
def TF_SameOperandsAndResultTypeResolveRef :
    NativeOpTrait<"TF::SameOperandsAndResultTypeResolveRef">;

// The op's behavior does not depend on the data layout (data format) of its
// operands; every element-wise op is an example.
def TF_LayoutAgnostic :
    NativeOpTrait<"TF::LayoutAgnostic">;

// The op must not be duplicated, e.g. because its implementation may carry
// state around.
def TF_CannotDuplicate :
    NativeOpTrait<"TF::CannotDuplicate">;

// The op must never be constant folded.
def TF_NoConstantFold :
    NativeOpTrait<"TF::NoConstantFold">;

// Coefficient-wise binary op with implicit broadcasting (e.g. tf.Sub).
def TF_CwiseBinary :
    NativeOpTrait<"TF::CwiseBinary">;

// Coefficient-wise unary op (e.g. tf.Sqrt).
def TF_CwiseUnary :
    NativeOpTrait<"TF::CwiseUnary">;
93
// Predicate: operand `opId` is broadcast-compatible with result `resId`.
// This is a variant of the broadcastable check that accounts for TensorFlow's
// subtype behavior (via mlir::tf_type::BroadcastCompatible).
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
  TCOpResIsShapedTypePred<opId, resId>,
  CPred<"mlir::tf_type::BroadcastCompatible("
        "$_op.getOperand(" # opId # ").getType(), "
        "$_op.getResult(" # resId # ").getType())">
]>;
100
101
// C++ predicate: the types produced by the expressions in `values` are
// mutually cast-compatible (checked by tf_type::AreCastCompatible).
class TF_AllTypesMatchPred<list<string> values> :
    CPred<"tf_type::AreCastCompatible(llvm::makeArrayRef({"
          # !interleave(values, ", ") # "}))">;

// Op trait: the named operands/results have dynamically equal
// (cast-compatible) types. Each name `n` becomes the expression
// `$n.getType()`.
class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<"all of {" # !interleave(names, ", ") #
                "} have dynamically equal types ",
                TF_AllTypesMatchPred<
                    !foreach(n, names, "$" # n # ".getType()")>>;
112
// Trait bundle for resource-allocating ops (ops with a `MemAlloc` side
// effect) asserting that every resource they return is unique.
//
// Attach it only where that invariant really holds. Attaching it where the
// invariant is false can lead to incorrect execution order. Leaving it off
// where it would hold only costs performance, through conservative ordering.
// Counter-example (invariant does not hold): `TF_VarHandleOp`.
def TF_UniqueResourceAllocation : TraitList<[
  TF_ResourceHandleAllocatorInterface,
  NativeOpTrait<"TF::UniqueResourceAllocation">
]>;
124
125//===----------------------------------------------------------------------===//
126// Rank/Shape helpers.
127//===----------------------------------------------------------------------===//
128
// C++ predicate: the n-th operand has an UnrankedTensorType.
class TF_OperandIsUnrankedPred<int n> : CPred<
  "$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// C++ predicate: the n-th result has an UnrankedTensorType.
class TF_ResultIsUnrankedPred<int n> : CPred<
  "$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Op trait: the n-th operand is either unranked or has rank exactly m.
class TF_OperandHasRank<int n, int m> : PredOpTrait<
  "operand " # n # " is " # m # "-D",
  Or<[
    TF_OperandIsUnrankedPred<n>,
    CPred<"$_op.getOperand(" # n #
          ").getType().cast<ShapedType>().getRank() == " # m>
  ]>>;

// Op trait: the n-th result is either unranked or has rank exactly m.
class TF_ResultHasRank<int n, int m> : PredOpTrait<
  "result " # n # " is " # m # "-D",
  Or<[
    TF_ResultIsUnrankedPred<n>,
    CPred<"$_op.getResult(" # n #
          ").getType().cast<ShapedType>().getRank() == " # m>
  ]>>;
148
149//===----------------------------------------------------------------------===//
150// TensorFlow resources and side effects
151//===----------------------------------------------------------------------===//
152
// Base class for TensorFlow side-effect resources. The C++ resource name is
// `::mlir::TF::ResourceEffects::` followed by `resourceKind`.
class TF_ResourceBase<string resourceKind>
    : Resource<"::mlir::TF::ResourceEffects::" # resourceKind>;

// Resource kinds.
def TF_VariableResource                 : TF_ResourceBase<"Variable">;
def TF_StackResource                    : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource              : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource                  : TF_ResourceBase<"Summary">;
def TF_LookupTableResource              : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource     : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource       : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource          : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource             : TF_ResourceBase<"TPUEmbedding">;
def TF_GeneratorOpResource              : TF_ResourceBase<"GeneratorOp">;
def TF_SendResource                     : TF_ResourceBase<"Send">;
def TF_RecvResource                     : TF_ResourceBase<"Recv">;
def TF_TPUExecuteResource               : TF_ResourceBase<"TPUExecute">;
def TF_RandomGeneratorResource          : TF_ResourceBase<"RandomGenerator">;
def TF_XlaHostComputeResource           : TF_ResourceBase<"XlaHostCompute">;
def TF_CollectiveReduceOrderingResource : TF_ResourceBase<"CollectiveReduceOrdering">;

// Fake resource that exists only to implement the `TF_MustExecute` trait
// defined later in this file.
def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">;
177
// Value-based side effects
//
// These traits are attached to individual op operands or results. They state
// which kind of resource is accessed through that value and how (read, write,
// allocate or free).

// Reads.
def TF_VariableRead             : MemRead<TF_VariableResource>;
def TF_StackRead                : MemRead<TF_StackResource>;
def TF_TensorArrayRead          : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead          : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead   : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead      : MemRead<TF_DatasetIteratorResource>;

// Writes.
def TF_VariableWrite             : MemWrite<TF_VariableResource>;
def TF_StackWrite                : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite          : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite              : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite          : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite   : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite      : MemWrite<TF_DatasetIteratorResource>;

// Allocations.
def TF_VariableAlloc             : MemAlloc<TF_VariableResource>;
def TF_StackAlloc                : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc          : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc              : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc          : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc   : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc      : MemAlloc<TF_DatasetIteratorResource>;

// Deallocations.
def TF_StackFree                : MemFree<TF_StackResource>;
def TF_TensorArrayFree          : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree              : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree   : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree      : MemFree<TF_DatasetIteratorResource>;
214
// Op-based side effects
//
// Op-based side-effect traits enforce execution-order constraints at the
// level of whole ops. They are mainly useful for ops that don't use resource
// handles; ops that do use handles typically carry value-based side effects
// instead.
//  - `read` effect: every op with the trait keeps its order relative to all
//    ops with unknown side effects (e.g. `stateful` ops).
//  - `write` effect: all ops with the trait stay in order relative to one
//    another, and also keep their order relative to all ops with unknown side
//    effects.
// Note that ops with a `read` effect may be pruned if nothing depends on them.
def TF_GeneratorOpSideEffect :
    MemoryEffects<[MemWrite<TF_GeneratorOpResource>]>;

def TF_TPUEmbeddingWriteEffect :
    MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
def TF_TPUEmbeddingReadEffect :
    MemoryEffects<[MemRead<TF_TPUEmbeddingResource>]>;

def TF_SendSideEffect :
    MemoryEffects<[MemWrite<TF_SendResource>]>;
def TF_RecvSideEffect :
    MemoryEffects<[MemWrite<TF_RecvResource>]>;
def TF_XlaHostComputeSideEffect :
    MemoryEffects<[MemWrite<TF_XlaHostComputeResource>]>;

def TF_TPUExecuteSideEffect :
    MemoryEffects<[MemWrite<TF_TPUExecuteResource>]>;

def TF_RandomGeneratorSideEffect :
    MemoryEffects<[MemWrite<TF_RandomGeneratorResource>]>;

// Dedicated effect that keeps `CollectiveReduce` ops in order.
def TF_CollectiveReduceOrderingEffect :
    MemoryEffects<[MemWrite<TF_CollectiveReduceOrderingResource>]>;

// Forces a side-effecting op to be executed even when MLIR would otherwise
// consider it dead (see b/195782952). It is modeled as a write to a fake
// resource that side-effect analysis ignores. As a result, it adds no
// execution-order constraints or control dependencies; for example, several
// ops with this trait do not have to execute in order.
def TF_MustExecute :
    MemoryEffects<[MemWrite<TF_MustExecuteResource>]>;
248
249//===----------------------------------------------------------------------===//
250// TensorFlow op definitions
251//===----------------------------------------------------------------------===//
252
// Base class for every op in the TensorFlow dialect.
class TF_Op<string mnemonic, list<Trait> traits = []>
    : Op<TensorFlowDialect, mnemonic, traits>;
255
256//===----------------------------------------------------------------------===//
257// TensorFlow attribute definitions
258//===----------------------------------------------------------------------===//
259
// Attribute constraint satisfied by the TensorFlow dialect attribute class
// `mlir::TF::<name>Attr`.
class TF_TensorFlowAttr<string name, string description>
    : Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
           "TensorFlow " # description # " attribute">;

// TensorFlow shape attribute. C++ accessors return
// llvm::Optional<llvm::ArrayRef<int64_t>>.
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  // Constant builder: creates a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
}

// Array attribute whose elements are all TF_ShapeAttr.
def TF_ShapeAttrArray
    : TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;
274
275//===----------------------------------------------------------------------===//
276// TensorFlow type definitions
277//===----------------------------------------------------------------------===//
278
// Matches any type defined by the TensorFlow dialect
// (anything that isa<mlir::TF::TensorFlowType>).
def TF_TFDialectType : Type<
    CPred<"$_self.isa<mlir::TF::TensorFlowType>()">,
    "TensorFlow type">;

// Buildable type constraint for the TensorFlow dialect type
// `mlir::TF::<name>Type`.
class TF_TensorFlowType<string name, string description>
    : Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
           "TensorFlow " # description # " type">,
      BuildableType<"getType<mlir::TF::" # name # "Type>()">;
288
289//===----------------------------------------------------------------------===//
290// Reference types
291
// Floating-point reference types.
def TF_Float16Ref  : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref  : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref  : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex reference types.
def TF_Complex64Ref  : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Signed integer reference types.
def TF_Int8Ref  : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

// Unsigned integer reference types.
def TF_Uint8Ref  : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized reference types.
def TF_Qint8Ref   : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref  : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref  : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref  : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Remaining reference types: bool, resource, string and variant.
def TF_BoolRef     : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef      : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef  : TF_TensorFlowType<"VariantRef", "variantref">;
325
326//===----------------------------------------------------------------------===//
327// Integer types (including corresponding reference types)
328
// Boolean: a 1-bit integer or the bool reference type.
def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

// Fixed-width integer types, each paired with its reference type.
def TF_Int8  : AnyTypeOf<[I8,  TF_Int8Ref],  "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<
  [I32, I64, TF_Int32Ref, TF_Int64Ref], "32/64-bit signed integer">;

// Unsigned fixed-width integer types, each paired with its reference type.
def TF_Uint8  : AnyTypeOf<[UI<8>,  TF_Uint8Ref],  "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Union of the unsigned integer constraints above.
def TF_UInt : AnyTypeOf<
  [TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64], "unsigned integer">;

// Union of the signed integer constraints above.
def TF_SInt : AnyTypeOf<
  [TF_Int8, TF_Int16, TF_Int32, TF_Int64], "signed integer">;

// Any integer, signed or unsigned.
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensors over the bool/integer element types above.
def TF_BoolTensor     : TensorOf<[TF_Bool]>;

def TF_IntTensor      : TensorOf<[TF_Int]>;
def TF_Int8Tensor     : TensorOf<[TF_Int8]>;
def TF_Int16Tensor    : TensorOf<[TF_Int16]>;
def TF_Int32Tensor    : TensorOf<[TF_Int32]>;
def TF_Int64Tensor    : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor    : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor   : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor   : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor   : TensorOf<[TF_Uint64]>;
368
369//===----------------------------------------------------------------------===//
370// Quantized types (including corresponding reference types)
371
// Quantized element types: the TensorFlow dialect type or its reference type.
def TF_Qint8 : AnyTypeOf<[
    TF_TensorFlowType<"Qint8", "qint8">,
    TF_Qint8Ref
  ], "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<[
    TF_TensorFlowType<"Qint16", "qint16">,
    TF_Qint16Ref
  ], "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<[
    TF_TensorFlowType<"Qint32", "qint32">,
    TF_Qint32Ref
  ], "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<[
    TF_TensorFlowType<"Quint8", "quint8">,
    TF_Quint8Ref
  ], "8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<[
    TF_TensorFlowType<"Quint16", "quint16">,
    TF_Quint16Ref
  ], "16-bit quantized unsigned integer">;

// Union of all quantized element types above.
def TF_Quantized : AnyTypeOf<[
    TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16
  ], "quantized">;
391
392//===----------------------------------------------------------------------===//
393// Floating-point types (including corresponding reference types)
394
// Floating-point element types, each paired with its reference type.
def TF_Float16  : AnyTypeOf<[F16,  TF_Float16Ref],  "16-bit float">;
def TF_Float32  : AnyTypeOf<[F32,  TF_Float32Ref],  "32-bit float">;
def TF_Float64  : AnyTypeOf<[F64,  TF_Float64Ref],  "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

// Only the 32- and 64-bit floats.
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

// Any floating-point element type above.
def TF_Float : AnyTypeOf<[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
                         "floating-point">;

// Tensors over the floating-point element types above.
def TF_FloatTensor    : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor  : TensorOf<[TF_Float16]>;
def TF_Float32Tensor  : TensorOf<[TF_Float32]>;
def TF_Float64Tensor  : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
413
414//===----------------------------------------------------------------------===//
415// Complex types (including corresponding reference types)
416
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
// Complex element types (f32- and f64-based), each paired with its reference
// type.
def TF_Complex64  : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
                              "64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
                              "128-bit complex">;

// Either complex element type.
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensors over the complex element types above.
def TF_ComplexTensor    : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor  : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
429
430//===----------------------------------------------------------------------===//
431// String/variant/resource types (including corresponding reference types)
432
// String element type (dialect type or reference type) and its tensors.
def TF_Str : AnyTypeOf<[TF_TensorFlowType<"String", "str">, TF_StrRef],
                       "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

// Variant element type (dialect type or reference type) and its tensors.
def TF_Variant : AnyTypeOf<[TF_TensorFlowType<"Variant", "var">, TF_VariantRef],
                           "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

// Resource element type (dialect type or reference type) and its tensors.
def TF_Resource : AnyTypeOf<[TF_TensorFlowType<"Resource", "res">, TF_ResourceRef],
                            "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
444
445//===----------------------------------------------------------------------===//
446// Multi-category type constraints
447
// Tensors whose element type may come from several categories.
def TF_IntOrF32OrF64Tensor : TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor  : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor       : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor      : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor   : TensorOf<[TF_Float, TF_Complex]>;

// Any numeric element type, including quantized and complex ones.
def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex],
                          "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

// Numbers excluding quantized types. Of the unsigned integers, only uint8
// is admitted.
def TF_NumberNotQuantizedTensor
    : TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8]>;

// The same set as above, plus strings.
def TF_NumberNotQuantizedOrStr
    : AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
464
465//===----------------------------------------------------------------------===//
466// Tensor and tensor element types
467
// Element type accepted by TensorFlow ops: any float, complex, integer or
// bool type above, or any TensorFlow dialect type
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType).
def TF_ElementType : Type<
    Or<[TF_Float.predicate, TF_Complex.predicate, TF_Int.predicate,
        TF_Bool.predicate, TF_TFDialectType.predicate]>,
    "tf.dtype">;

// Tensor of any TensorFlow element type.
def TF_Tensor : TensorOf<[TF_ElementType]>;
479
480//===----------------------------------------------------------------------===//
// TensorFlow attribute constraints and derived attributes
482//===----------------------------------------------------------------------===//
483
484//===----------------------------------------------------------------------===//
485// String attribute constraints
486
// A string attribute whose value is one of the strings in `cases`.
//
// The predicate first checks that the attribute is a StringAttr. A
// non-string attribute then fails verification instead of reaching the
// `cast<StringAttr>()` calls; the `And` combinator joins its operands with a
// short-circuiting `&&`.
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
  And<[
    CPred<"$_self.isa<StringAttr>()">,
    CPred<!foldl(
        "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
        !foreach(case, !tail(cases),
                 "$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
        prev, cur, prev # " || " # cur)>
  ]>,
  "string attribute whose value is " #
    !foldl(/*init*/!head(cases), /*list*/!tail(cases),
           prev, cur, prev # ", or " # cur)>;

// TODO: Use EnumAttr to define the common attribute cases

// String attribute restricted to the convnet data formats "NHWC" and "NCHW".
// Like TF_AnyStrAttrOf, it checks for a StringAttr before casting.
def TF_ConvnetDataFormatAttr : StringBasedAttr<
    And<[
      CPred<"$_self.isa<StringAttr>()">,
      CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
            "$_self.cast<StringAttr>().getValue() == \"NCHW\"">
    ]>,
    "'NHWC' or 'NCHW' convnet data format">;
504
505//===----------------------------------------------------------------------===//
506// Type attributes
507
// Derived attribute: the number of values in the `idx`-th ODS-declared
// (variadic) operand group. It is materialized as an i64 integer attribute.
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSOperands(" # idx # ");\n" #
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// Derived type attribute: the element type of the first value in the `idx`-th
// ODS-declared operand group. The code dereferences `begin()`, so the group
// must be non-empty. For a variadic group the result is only meaningful when
// all of its tensors share one element type.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;

// Derived attribute: the element type of every value in the `idx`-th
// ODS-declared variadic operand group. Use it when those element types may
// differ. It is materialized as an ArrayAttr of TypeAttrs.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
  "mlir::OperandElementTypeRange",
  "auto values = getODSOperands(" # idx # ");\n" #
  "return {mlir::OperandElementTypeIterator(values.begin()), " #
  "mlir::OperandElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// Derived attribute: the shape of every value in the `idx`-th ODS-declared
// variadic operand group. Use it when those shapes may differ. It is
// materialized as an ArrayAttr of TF shape attributes.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
  "::mlir::TF::OperandShapeRange",
  "auto values = getODSOperands(" # idx # ");\n" #
  "return {mlir::TF::OperandShapeIterator(values.begin()), " #
  "mlir::TF::OperandShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto shape : $_self)
        ret.push_back(mlir::TF::ShapeAttr::get($_ctxt, shape));
      return ret;
    }())
  }]
>;
563
// Derived attribute: the number of values in the `idx`-th ODS-declared
// (variadic) result group. It is materialized as an i64 integer attribute.
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSResults(" # idx # ");\n" #
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// Derived type attribute: the element type of the first value in the `idx`-th
// ODS-declared result group. The code dereferences `begin()`, so the group
// must be non-empty. For a variadic group the result is only meaningful when
// all of its tensors share one element type.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;

// Derived attribute: the element type of every value in the `idx`-th
// ODS-declared variadic result group. Use it when those element types may
// differ. It is materialized as an ArrayAttr of TypeAttrs.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
  "mlir::ResultElementTypeRange",
  "auto values = getODSResults(" # idx # ");\n" #
  "return {mlir::ResultElementTypeIterator(values.begin()), " #
  "mlir::ResultElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// Derived attribute: the shape of every value in the `idx`-th ODS-declared
// variadic result group. Use it when those shapes may differ. It is
// materialized as an ArrayAttr of TF shape attributes.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
  "mlir::TF::ResultShapeRange",
  "auto values = getODSResults(" # idx # ");\n" #
  "return {mlir::TF::ResultShapeIterator(values.begin()), " #
  "mlir::TF::ResultShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto shape : $_self)
        ret.push_back(mlir::TF::ShapeAttr::get($_ctxt, shape));
      return ret;
    }())
  }]
>;
619
// Derived attribute: the type of the op's first result as a ShapedType. It is
// materialized as a TF shape attribute. The code dereferences the first
// result type and casts it to ShapedType, so the op must have at least one
// shaped result.
def TF_DerivedResultShapeAttr : DerivedAttr<
  "ShapedType",
  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
  [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>;

// Type attribute constrained to IntegerType; its C++ accessor returns a plain
// `Type`.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}
628
629//===----------------------------------------------------------------------===//
630// TensorFlow common builders
631//===----------------------------------------------------------------------===//
632
// Mixin providing a `(x, y)` builder for broadcasting binary ops. The result
// type is the broadcast of both operand types and keeps their element type.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y), [{
  // A null type signals that the operands cannot be broadcast together. An
  // error is emitted in that case, but the op is still built with the null
  // type. NOTE(review): callers should not rely on the resulting op.
  Type broadcastedType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!broadcastedType) {
    mlir::emitError($_state.location, "non-broadcastable operands");
  }
  return build($_builder, $_state, broadcastedType, x, y);
}]>
  ];
}
646
// Mixin providing a `(x, y)` builder for broadcasting comparison ops. The
// result is a tensor of i1:
//  - unranked if either operand is unranked;
//  - otherwise ranked with the broadcast of both operand shapes;
//  - unranked (after emitting an error) if the shapes cannot be broadcast.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  Type resultType;
  if (x.getType().isa<UnrankedTensorType>() ||
      y.getType().isa<UnrankedTensorType>()) {
    resultType = UnrankedTensorType::get($_builder.getI1Type());
  } else {
    SmallVector<int64_t, 4> resultShape;
    if (OpTrait::util::getBroadcastedShape(
            x.getType().cast<ShapedType>().getShape(),
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
      resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
    } else {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
      // The broadcast failed, so `resultShape` does not describe a valid
      // result shape. Do not derive a ranked type from it; use an unranked
      // i1 tensor instead.
      resultType = UnrankedTensorType::get($_builder.getI1Type());
    }
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}
671
672#endif // TF_OP_BASE
673