/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow traits, types, attributes, and
// builders.

#ifndef TF_OP_BASE
#define TF_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"

//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//

// The `tf` dialect; its C++ classes are generated into `::mlir::TF`.
def TensorFlowDialect : Dialect {
  let name = "tf";

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];

  let cppNamespace = "::mlir::TF";

  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
}

//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//

// Specify this trait if the op requires all outputs to have the same type and
// the inputs either have the same type as result or a ref type corresponding to
// the result type.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
  "TF::OperandsSameAsResultsTypeOrRef">;

// Op has the same operand and result element types (or type itself, if scalar)
// after resolving reference types (i.e., after converting reference types to
// their corresponding TensorFlow or standard types).
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultElementTypeResolveRef">;

// Op has the same operand and result types after resolving reference types
// (i.e., after converting reference types to their corresponding TensorFlow or
// standard types).
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultTypeResolveRef">;

// Layout agnostic operations do not depend on the operands' data layout (data
// format); as an example, all element-wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// Trait to indicate operations that cannot be duplicated as they might carry
// certain state around within their implementations.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// Trait to indicate an operation cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient-wise binary operation with implicit broadcasting support, for
// example tf.Sub operation.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient-wise unary operation, for example tf.Sqrt operation.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;

// Variant of broadcastable trait that considers TF's subtype behavior.
// Holds when operand `opId` and result `resId` are shaped types and
// mlir::tf_type::BroadcastCompatible accepts the pair of types.
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<opId, resId>,
    CPred<"mlir::tf_type::BroadcastCompatible("
              "$_op.getOperand(" # opId # ").getType(), "
              "$_op.getResult(" # resId # ").getType())">]>;

// Predicate that holds when all of the given C++ type expressions are
// cast-compatible according to tf_type::AreCastCompatible.
class TF_AllTypesMatchPred<list<string> values> :
    CPred<"tf_type::AreCastCompatible(llvm::makeArrayRef({" #
      !interleave(values, ", ") # "}))">;

// Op trait requiring that all of the named operands/results have
// cast-compatible types. Each name `n` is expanded to `$n.getType()` before
// being handed to TF_AllTypesMatchPred.
class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<
        "all of {" # !interleave(names, ", ") #
          "} have dynamically equal types ",
        TF_AllTypesMatchPred<
            !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

// This trait indicates that all returned resources are unique for a
// resource-allocating op (i.e. op with `MemAlloc` side effect).
//
// Note that if the trait is used where this invariant is not true, then this
// might lead to incorrect execution order, while if not used where it should
// be, it can only lead to reduced performance due to conservative ordering.
// Example op where the invariant is not true: `TF_VarHandleOp`.
def TF_UniqueResourceAllocation: TraitList<[
    TF_ResourceHandleAllocatorInterface,
    NativeOpTrait<"TF::UniqueResourceAllocation">
]>;

//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

// Predicate: the n-th operand's type is an UnrankedTensorType.
class TF_OperandIsUnrankedPred<int n> :
  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// Predicate: the n-th result's type is an UnrankedTensorType.
class TF_ResultIsUnrankedPred<int n> :
  CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Verifies that the n-th operand has unknown rank or has rank m.
class TF_OperandHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[TF_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
        ").getType().cast<ShapedType>().getRank() == " # m>]>>;

// Verifies that the n-th result has unknown rank or has rank m.
class TF_ResultHasRank<int n, int m> :
  PredOpTrait<"result " # n # " is " # m # "-D",
    Or<[TF_ResultIsUnrankedPred<n>,
      CPred<"$_op.getResult(" # n #
        ").getType().cast<ShapedType>().getRank() == " # m>]>>;

//===----------------------------------------------------------------------===//
// TensorFlow resources and side effects
//===----------------------------------------------------------------------===//

// Base class for TF side-effect resources; `resourceKind` names a C++ class
// nested in ::mlir::TF::ResourceEffects.
class TF_ResourceBase<string resourceKind> :
  Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
}

// Resource types
def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
def TF_GeneratorOpResource : TF_ResourceBase<"GeneratorOp">;
def TF_SendResource : TF_ResourceBase<"Send">;
def TF_RecvResource : TF_ResourceBase<"Recv">;
def TF_TPUExecuteResource : TF_ResourceBase<"TPUExecute">;
def TF_RandomGeneratorResource : TF_ResourceBase<"RandomGenerator">;
def TF_XlaHostComputeResource : TF_ResourceBase<"XlaHostCompute">;
def TF_CollectiveReduceOrderingResource : TF_ResourceBase<"CollectiveReduceOrdering">;

// Fake resource, see `TF_MustExecute` below.
def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">;

// Value-based side effects
//
// Value-based side effect traits are attached to op operands or results to
// signal what type of resource is accessed and in which way.
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;

def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;

def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;

def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;

// Op-based side effects

// Op-based side effect traits can be used to enforce certain execution order
// constraints, in particular for ops that don't use resource handles (those
// typically have value-based side effects). For a `read` effect, all instances
// of ops with the trait keep their order relative to all ops with unknown side
// effects (e.g. `stateful` ops). For a `write` effect, all instances of ops
// with the trait stay in order, and they also keep their order relative to all
// unknown side-effecting ops. Note that for `read` effects ops might be pruned
// if nothing depends on them.
def TF_GeneratorOpSideEffect : MemoryEffects<[MemWrite<TF_GeneratorOpResource>]>;

def TF_TPUEmbeddingWriteEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
def TF_TPUEmbeddingReadEffect : MemoryEffects<[MemRead<TF_TPUEmbeddingResource>]>;

def TF_SendSideEffect : MemoryEffects<[MemWrite<TF_SendResource>]>;
def TF_RecvSideEffect : MemoryEffects<[MemWrite<TF_RecvResource>]>;
def TF_XlaHostComputeSideEffect : MemoryEffects<[MemWrite<TF_XlaHostComputeResource>]>;

def TF_TPUExecuteSideEffect : MemoryEffects<[MemWrite<TF_TPUExecuteResource>]>;

def TF_RandomGeneratorSideEffect : MemoryEffects<[MemWrite<TF_RandomGeneratorResource>]>;

// Special effect for keeping `CollectiveReduce` ops in order.
def TF_CollectiveReduceOrderingEffect : MemoryEffects<[MemWrite<TF_CollectiveReduceOrderingResource>]>;

// Trait for enforcing that a side-effecting op is executed, even if it would be
// considered dead by MLIR (see b/195782952).
// The trait is implemented as a write effect for a fake resource which is
// ignored by side effect analysis, so it does not affect execution order
// constraints and control dependencies at all (for example, multiple ops with
// this trait do not have to execute in order).
def TF_MustExecute : MemoryEffects<[MemWrite<TF_MustExecuteResource>]>;

//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

// Base class for all ops in the TensorFlow dialect.
class TF_Op<string mnemonic, list<Trait> traits = []> :
    Op<TensorFlowDialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

// Attribute constraint matching `mlir::TF::<name>Attr`.
class TF_TensorFlowAttr <string name, string description> :
    Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
         "TensorFlow " # description # " attribute">;

// A TF shape attribute; its storage is converted to an optional list of
// dimension sizes.
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";

  // Create a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}

def TF_ShapeAttrArray :
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;

//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//

// Any tensor element type defined in the TensorFlow dialect
def TF_TFDialectType :
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Class for any TensorFlow dialect specific type
class TF_TensorFlowType <string name, string description> :
    Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
         "TensorFlow " # description # " type">,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Reference types

// Float reference types
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;

//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)

def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
                           "32/64-bit signed integer">;

def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
                        "unsigned integer">;

// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
                        "signed integer">;

// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor types
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;

//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)

def TF_Qint8 : AnyTypeOf<
  [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
  "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
  [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
  "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
  [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
  "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
  [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
  "8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
  [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
  "16-bit quantized unsigned integer">;

// Any quantized type
def TF_Quantized : AnyTypeOf<
  [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;

//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)

def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
  "floating-point">;

// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
  "64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
  "128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;

//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)

def TF_Str : AnyTypeOf<
  [TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant : AnyTypeOf<
  [TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource : AnyTypeOf<
  [TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;

//===----------------------------------------------------------------------===//
// Multi-category type constraints

def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;

def TF_Number : AnyTypeOf<
  [TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

// Note: of the unsigned integer types, only uint8 is included here.
def TF_NumberNotQuantizedTensor : TensorOf<
  [TF_Float, TF_SInt, TF_Complex, TF_Uint8]>;

def TF_NumberNotQuantizedOrStr :
  AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;

//===----------------------------------------------------------------------===//
// Tensor and tensor element types

// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
def TF_ElementType : Type<Or<[TF_Float.predicate,
                              TF_Complex.predicate,
                              TF_Int.predicate,
                              TF_Bool.predicate,
                              TF_TFDialectType.predicate]>,
                          "tf.dtype">;

// Any TensorFlow tensor type
def TF_Tensor : TensorOf<[TF_ElementType]>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// String attribute constraints

// A string attribute whose value is one of the values in `cases`.
// The predicate is an `||` chain of equality checks, one per entry in `cases`.
// The description lists the cases joined by ", or ".
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
  CPred<!foldl(
      "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
      !foreach(case, !tail(cases),
               "$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
      prev, cur, prev # " || " # cur)>,
  "string attribute whose value is " #
    !foldl(/*init*/!head(cases), /*list*/!tail(cases),
           prev, cur, prev # ", or " # cur)>;

// TODO: Use EnumAttr to define the common attribute cases

def TF_ConvnetDataFormatAttr : StringBasedAttr<
    CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
          "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
    "'NHWC' or 'NCHW' convnet data format">;

//===----------------------------------------------------------------------===//
// Type attributes

// A derived attribute that returns the size of `idx`-th ODS-declared variadic
// operand.
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSOperands(" # idx # ");\n"
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// A derived attribute that returns the element type of `idx`-th ODS-declared
// operand. If the `idx`-th operand is a variadic operand, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic operand has at least one tensor and the tensors all have
// the same element type.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;

// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// operand. This returns a list of element types so it is used for variadic
// operands that can have different element types.
// Materialized as an ArrayAttr of TypeAttr, one per element type.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
  "mlir::OperandElementTypeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::OperandElementTypeIterator(values.begin()), "
  "mlir::OperandElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt,
    [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic operand.
// This returns a list of shapes so it is used for variadic operands that
// can have different shapes. Materialized as an ArrayAttr of TF ShapeAttr.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
  "::mlir::TF::OperandShapeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::TF::OperandShapeIterator(values.begin()), "
  "mlir::TF::OperandShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt,
      [&](){
        llvm::SmallVector<Attribute, 4> ret;
        for (auto shape : $_self)
          ret.push_back(mlir::TF::ShapeAttr::get($_ctxt, shape));
        return ret;
      }())
  }]
>;

// A derived attribute that returns the size of `idx`-th ODS-declared variadic
// result.
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSResults(" # idx # ");\n"
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// A derived attribute that returns the element type of `idx`-th ODS-declared
// result. If the `idx`-th result is a variadic result, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic result has at least one tensor and the tensors all have
// the same element type.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;

// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// result. This returns a list of element types so it is used for variadic
// results that can have different element types. Materialized as an ArrayAttr
// of TypeAttr.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
  "mlir::ResultElementTypeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::ResultElementTypeIterator(values.begin()), "
  "mlir::ResultElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt,
    [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic result.
// This returns a list of shapes so it is used for variadic results that
// can have different shapes. Materialized as an ArrayAttr of TF ShapeAttr.
// The return type is fully qualified, matching TF_DerivedOperandShapeListAttr.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
  "::mlir::TF::ResultShapeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::TF::ResultShapeIterator(values.begin()), "
  "mlir::TF::ResultShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctxt,
      [&](){
        llvm::SmallVector<Attribute, 4> ret;
        for (auto shape : $_self)
          ret.push_back(mlir::TF::ShapeAttr::get($_ctxt, shape));
        return ret;
      }())
  }]
>;

// A derived attribute that returns the shape of the first result type.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
  [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>;

// Type attribute constrained to IntegerType; accessor returns a plain `Type`.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}

//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//

// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
// NOTE(review): when the operands are not broadcastable, an error is emitted
// but build() is still called with the null result type returned by
// getBroadcastedType; callers presumably must avoid this path — confirm.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  auto resultType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!resultType)
    mlir::emitError($_state.location, "non-broadcastable operands");
  return build($_builder, $_state, resultType, x, y);
}]>];
}

// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool element type: unranked if either operand
// is unranked, otherwise ranked with the broadcast shape.
// NOTE(review): if the shapes do not broadcast, an error is emitted but the op
// is still built with whatever `resultShape` holds at that point — confirm.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  Type resultType;
  if (x.getType().isa<UnrankedTensorType>() ||
      y.getType().isa<UnrankedTensorType>()) {
    resultType = UnrankedTensorType::get($_builder.getI1Type());
  } else {
    SmallVector<int64_t, 4> resultShape;
    if (!OpTrait::util::getBroadcastedShape(
            x.getType().cast<ShapedType>().getShape(),
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
    }

    resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}

#endif // TF_OP_BASE