xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringSwitch.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/Dialect/Traits.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Builders.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
36 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
37 #include "mlir/IR/Matchers.h"  // from @llvm-project
38 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
39 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
40 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
41 #include "mlir/IR/Types.h"  // from @llvm-project
42 #include "mlir/IR/Value.h"  // from @llvm-project
43 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
44 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
45 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
47 
48 namespace mlir {
49 namespace tf_executor {
50 
51 //===----------------------------------------------------------------------===//
52 // TF Executor Dialect
53 //===----------------------------------------------------------------------===//
54 
55 namespace {
56 
57 struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
58   using DialectInlinerInterface::DialectInlinerInterface;
59 
60   //===--------------------------------------------------------------------===//
61   // Analysis Hooks
62   //===--------------------------------------------------------------------===//
63 
64   // Allow all call operations to be inlined.
isLegalToInlinemlir::tf_executor::__anon2366cd000111::TensorFlowExecutorInlinerInterface65   bool isLegalToInline(Operation *call, Operation *callable,
66                        bool wouldBeCloned) const final {
67     return true;
68   }
69   // Override the inlining hook to determine if 'src' can be inlined into
70   // 'dest'.
isLegalToInlinemlir::tf_executor::__anon2366cd000111::TensorFlowExecutorInlinerInterface71   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
72                        BlockAndValueMapping &value_mapping) const final {
73     // Allow inlining into tf.island regions if the incoming region has a single
74     // block.
75     return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
76            llvm::hasSingleElement(*src);
77   }
78 };
79 
80 struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface {
81   using DialectFoldInterface::DialectFoldInterface;
82 
83   // Registered hook to check if the given region, which is attached to an
84   // operation that is *not* isolated from above (i.e. no internal regions
85   // reference values defined in an enclosing region), should be used when
86   // materializing constants.
87   // In the executor dialect we materialize inside an island.
shouldMaterializeIntomlir::tf_executor::__anon2366cd000111::TensorFlowExecutorDialectFoldInterface88   bool shouldMaterializeInto(Region *region) const final {
89     return isa<tf_executor::IslandOp>(region->getParentOp());
90   }
91 };
92 
93 }  // namespace
94 
TensorFlowExecutorDialect(MLIRContext * context)95 TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
96     : Dialect(/*name=*/"tf_executor", context,
97               TypeID::get<TensorFlowExecutorDialect>()) {
98   addOperations<
99 #define GET_OP_LIST
100 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
101       >();
102 
103   addInterfaces<TensorFlowExecutorInlinerInterface,
104                 TensorFlowExecutorDialectFoldInterface>();
105 
106   addTypes<ControlType, TokenType>();
107 }
108 
parseType(DialectAsmParser & parser) const109 Type TensorFlowExecutorDialect::parseType(DialectAsmParser &parser) const {
110   StringRef data_type;
111   if (parser.parseKeyword(&data_type)) return Type();
112 
113   if (data_type == "control") return ControlType::get(getContext());
114   if (data_type == "token") return TokenType::get(getContext());
115   parser.emitError(parser.getNameLoc())
116       << "unknown tf_executor type: " << data_type;
117   return nullptr;
118 }
119 
printType(Type type,DialectAsmPrinter & os) const120 void TensorFlowExecutorDialect::printType(Type type,
121                                           DialectAsmPrinter &os) const {
122   if (type.isa<ControlType>()) {
123     os << "control";
124     return;
125   }
126   if (type.isa<TokenType>()) {
127     os << "token";
128     return;
129   }
130   os << "<unknown tf_executor type>";
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Implementation for all the operations defined in ODS (op definition spec).
135 //===----------------------------------------------------------------------===//
136 
137 namespace {
138 
139 // Verifies that every control operands are at the end of the list.
140 // Used by the constraint `ControlOperandsAfterAllData` in ODS.
VerifyControlOperandsAfterAllData(Operation * op)141 LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
142   bool found_control = false;
143   for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
144     if (op->getOperand(operand_idx).getType().isa<ControlType>()) {
145       found_control = true;
146       continue;
147     }
148     if (found_control)
149       return op->emitOpError() << "found non-control operand #" << operand_idx
150                                << " after control operand";
151   }
152   return success();
153 }
154 
155 }  // anonymous namespace
156 
157 //===----------------------------------------------------------------------===//
158 // tf_executor.graph
159 //===----------------------------------------------------------------------===//
160 
GetFetch()161 FetchOp GraphOp::GetFetch() { return llvm::cast<FetchOp>(GetBody().back()); }
162 
verify()163 LogicalResult GraphOp::verify() {
164   GraphOp graph = *this;
165   auto *executorDialect = graph->getDialect();
166 
167   if (graph.GetBody().empty())
168     return graph.emitOpError() << "expects a non-empty body";
169 
170   // Only tf_executor dialect operations are allowed to be immediately nested
171   // in a tf_executor.graph region.
172   for (Operation &op : graph.GetBody()) {
173     if (op.getDialect() != executorDialect)
174       return op.emitOpError() << "unallowed inside a tf_executor.graph region";
175     if (isa<GraphOp>(op))
176       return op.emitOpError()
177              << "unallowed directly inside another tf_executor.graph";
178   }
179 
180   Operation &fetch = graph.GetBody().back();
181   if (!isa<FetchOp>(fetch))
182     return fetch.emitOpError()
183            << "invalid tf_executor.graph terminator, fetch expected";
184 
185   // Ensure that the fetch terminator operands matches the graph result type.
186   // All the non-control operands of the fetch operation must match the graph
187   // returned value.
188   if (fetch.getNumOperands() < graph.getNumResults())
189     return fetch.emitOpError() << "does not have enough operands to cover the "
190                                   "graph returned values";
191   for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
192     Value operand = fetch.getOperand(i);
193     // Break out of the loop at the first control operand encountered.
194     const int64_t num_results = graph.getNumResults();
195     if (operand.getType().isa<ControlType>()) {
196       if (i != num_results)
197         return fetch.emitOpError()
198                << "operand #" << i
199                << " is a control type, can't be bound to a graph result";
200       break;
201     }
202     if (i >= num_results)
203       return fetch.emitOpError()
204              << "operand #" << i << " does not have a graph results to bind";
205     if (graph.getResult(i).getType() != operand.getType()) {
206       return fetch.emitOpError()
207              << "operand #" << i << " type mismatch graph results ("
208              << graph.getResult(i).getType() << " != " << operand.getType()
209              << ")";
210     }
211   }
212   return success();
213 }
214 
print(OpAsmPrinter & p)215 void GraphOp::print(OpAsmPrinter &p) {
216   p << ' ';
217   p.printRegion(getOperation()->getRegion(0));
218   p.printOptionalAttrDict(getOperation()->getAttrs());
219 }
220 
parse(OpAsmParser & parser,OperationState & result)221 ParseResult GraphOp::parse(OpAsmParser &parser, OperationState &result) {
222   llvm::SMLoc loc = parser.getCurrentLocation();
223 
224   // Parse the body region.
225   Region &body = *result.addRegion();
226   if (parser.parseRegion(body)) return failure();
227 
228   // Ensure that the region is well formed: it contains at least a block with
229   // a FetchOp terminator.
230   GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);
231 
232   if (!llvm::hasSingleElement(body))
233     return parser.emitError(loc) << "expects a single block region";
234 
235   // Get the results type from the terminator type inside the graph.
236   Operation &fetch = body.back().back();
237   if (!isa<FetchOp>(fetch))
238     return parser.emitError(loc) << "expects a tf_executor.fetch terminator";
239 
240   // The return value of the graph operation are the non-control operands of
241   // the fetch operation.
242   result.types.reserve(fetch.getNumOperands());
243   for (Type type : fetch.getOperandTypes()) {
244     if (type.isa<ControlType>()) break;
245     result.types.push_back(type);
246   }
247 
248   // Parse the optional attribute list.
249   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
250 
251   return success();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // tf_executor.fetch
256 //===----------------------------------------------------------------------===//
257 
258 //===----------------------------------------------------------------------===//
259 // tf_executor.island
260 //===----------------------------------------------------------------------===//
261 
GetYield()262 YieldOp IslandOp::GetYield() { return llvm::cast<YieldOp>(GetBody().back()); }
263 
264 // Checks if a tf_executor.island wraps a single operation and the single
265 // operation results are perfectly forwarded to the islands yield.
WrapsSingleOp()266 bool IslandOp::WrapsSingleOp() {
267   auto body = GetBody().without_terminator();
268   if (!hasSingleElement(body)) return false;
269 
270   Operation &wrapped_op = *body.begin();
271   YieldOp yield = GetYield();
272   return wrapped_op.getNumResults() == yield.getNumOperands() &&
273          std::equal(wrapped_op.getResults().begin(),
274                     wrapped_op.getResults().end(), yield.getOperands().begin());
275 }
276 
verify()277 mlir::LogicalResult IslandOp::verify() {
278   IslandOp island = *this;
279   if (!island.GetBody().args_empty())
280     return island.emitOpError() << "expects body without any arguments";
281 
282   Operation &yield = island.GetBody().back();
283   if (!isa<YieldOp>(yield))
284     return yield.emitOpError()
285            << "invalid tf_executor.island terminator, yield expected";
286 
287   // Ensure that the yield terminator operands matches the island results type.
288   int result_count = island.getNumResults() - 1;  // -1 for the control token
289   const int num_operands = yield.getNumOperands();
290   if (num_operands != result_count)
291     return yield.emitOpError()
292            << "has " << yield.getNumOperands()
293            << " operand, but island returns " << result_count;
294   for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
295     if (island.getResult(operand_idx).getType() !=
296         yield.getOperand(operand_idx).getType())
297       return yield.emitOpError()
298              << "operand #" << operand_idx << " type mismatch island results";
299   }
300 
301   // Check that there aren't any control results other than the last one.
302   Type control_type = ControlType::get(island.getContext());
303   for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
304     if (island.getResult(operand_idx).getType() == control_type)
305       return yield.emitOpError()
306              << "unexpected control type for operand #" << operand_idx;
307   }
308   return success();
309 }
310 
print(OpAsmPrinter & p)311 void IslandOp::print(OpAsmPrinter &p) {
312   if (getNumOperands()) {
313     // These are always control operand, no explicit type needed.
314     p << '(';
315     p.printOperands(getOperands());
316     p << ')';
317   }
318 
319   // Check if we can print the short "wraps" form: that is if the island
320   // contains a single operation and the result of this operation are perfectly
321   // forwarded to the yield.
322   if (getOperation()->getAttrs().empty() && WrapsSingleOp()) {
323     Operation &wrapped_op = GetBody().front();
324     YieldOp yield_op = GetYield();
325     // The "wraps" syntax only encodes a single location.
326     // In order to correctly round-trip, we can only use this syntax when all
327     // the locations are identical.
328     if (wrapped_op.getLoc() == getLoc() && yield_op.getLoc() == getLoc()) {
329       p << " wraps ";
330       p.printGenericOp(&wrapped_op);
331       return;
332     }
333   }
334   p << ' ';
335   p.printRegion(getOperation()->getRegion(0));
336   p.printOptionalAttrDict(getOperation()->getAttrs());
337 }
338 
parse(OpAsmParser & parser,OperationState & result)339 ParseResult IslandOp::parse(OpAsmParser &parser, OperationState &result) {
340   llvm::SMLoc loc = parser.getCurrentLocation();
341   Type control_type = ControlType::get(parser.getBuilder().getContext());
342 
343   // Parse optional argument list (control dependencies only).
344   SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos;
345   if (parser.parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
346     return failure();
347   if (!op_infos.empty()) {
348     SmallVector<Type, 2> types(op_infos.size(), control_type);
349     if (parser.resolveOperands(op_infos, types, loc, result.operands))
350       return failure();
351   }
352 
353   // Parse the body region.
354   Region &body = *result.addRegion();
355 
356   if (succeeded(parser.parseOptionalKeyword("wraps"))) {
357     // If we parse the short version of the island, we have an operation in the
358     // generic form that follows the "wraps" keyword. Parse it inside the region
359     // and forward all of its results as-is to the yield operation.
360     body.push_back(new Block);
361     Block &block = body.back();
362     Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
363     if (!wrapped_op) return failure();
364     OpBuilder builder(parser.getBuilder().getContext());
365     builder.setInsertionPointToEnd(&block);
366     builder.create<YieldOp>(wrapped_op->getLoc(), wrapped_op->getResults());
367     result.location = wrapped_op->getLoc();
368   } else if (parser.parseRegion(body)) {
369     return failure();
370   }
371 
372   IslandOp::ensureTerminator(body, parser.getBuilder(), result.location);
373 
374   // Get the results type for the island from the terminator operands.
375   Operation &yield = body.back().back();
376   result.types.reserve(yield.getNumOperands() + 1);
377   result.types.append(yield.operand_type_begin(), yield.operand_type_end());
378   result.types.push_back(control_type);
379 
380   // Parse the optional attribute list.
381   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
382   return success();
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // tf_executor.yield
387 //===----------------------------------------------------------------------===//
388 
389 //===----------------------------------------------------------------------===//
390 // tf_executor.Switch
391 //===----------------------------------------------------------------------===//
392 
parse(OpAsmParser & parser,OperationState & result)393 ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
394   SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
395   SmallVector<Type, 1> types;
396   if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
397     return failure();
398   if (types.size() != 1)
399     return parser.emitError(parser.getNameLoc())
400            << " expects only a single data type";
401 
402   // Support parsing either a functional type (in which case all the types are
403   // fully qualified) or a short form with a single type (in which case the data
404   // input and the outputs are all using this type and predicate is tensor<i1>
405   // type).
406   if (types.front().isa<FunctionType>()) {
407     FunctionType type = types.front().cast<FunctionType>();
408     if (type.getNumInputs() < 2)
409       return parser.emitError(parser.getNameLoc())
410              << " expects a single data type and a predicate";
411     result.types.assign(type.getResults().begin(), type.getResults().end());
412     types.assign(type.getInputs().begin(), type.getInputs().end());
413   } else {
414     if (op_infos.size() < 2)
415       return parser.emitError(parser.getNameLoc())
416              << " expects a single data type and a predicate";
417     Type control_type = ControlType::get(parser.getBuilder().getContext());
418     result.types.append(2, types[0]);
419     result.types.push_back(control_type);
420     Type i1_type = parser.getBuilder().getI1Type();
421     RankedTensorType predicate_type = RankedTensorType::get({}, i1_type);
422     types.push_back(predicate_type);
423     types.append(op_infos.size() - 2, control_type);
424   }
425 
426   llvm::SMLoc loc = parser.getCurrentLocation();
427   if (parser.resolveOperands(op_infos, types, loc, result.operands))
428     return failure();
429 
430   return parser.parseOptionalAttrDict(result.attributes);
431 }
432 
print(OpAsmPrinter & p)433 void SwitchOp::print(OpAsmPrinter &p) {
434   p << ' ';
435   p.printOperands(getOperands());
436   Type data_operand_ty = data().getType();
437   // If the types aren't perfectly matching, print the functional type syntax
438   // else print the shorter single type.
439   p << " : ";
440   if (trueOutput().getType() != data_operand_ty ||
441       falseOutput().getType() != data_operand_ty ||
442       predicate().getType().isa<UnrankedTensorType>()) {
443     p.printFunctionalType(getOperation());
444   } else {
445     p << getType(0);
446   }
447   p.printOptionalAttrDict(getOperation()->getAttrs());
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // tf_executor.SwitchN
452 //===----------------------------------------------------------------------===//
453 
verify()454 LogicalResult SwitchNOp::verify() {
455   SwitchNOp switchn = *this;
456   IntegerAttr num_outs = switchn->getAttrOfType<IntegerAttr>("num_outs");
457   if (!num_outs)
458     return switchn.emitOpError() << "expects a `num_outs` integer attribute";
459 
460   // Expect num_outs results + 1 control output.
461   if (switchn.getNumResults() != num_outs.getInt() + 1)
462     return switchn.emitOpError()
463            << "expect `num_outs` (" << num_outs.getInt() << ") results but got "
464            << (switchn.getNumResults() - 1);
465 
466   // Check that operand can be broadcasted to each output type.
467   auto operand0_type = switchn.getOperand(0).getType();
468   TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
469   if (!operand0_tensor_type) {
470     return switchn.emitOpError()
471            << "expects data operand to have tensor type but got "
472            << operand0_type;
473   }
474   for (Type output_type : switchn.getResultTypes()) {
475     if (output_type.isa<ControlType>()) break;
476 
477     TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
478     if (!output_tensor_type) {
479       return switchn.emitOpError()
480              << "expects outputs to have tensor type but got " << output_type;
481     }
482 
483     // If the output type is a ref type, then the operand type should also be of
484     // the same ref type. However, if the output type is a non-ref type T, then
485     // the operand can be tensor of type T or T_REF.
486     bool is_output_ref =
487         output_tensor_type.getElementType().isa<tf_type::TensorFlowRefType>();
488     if (is_output_ref && !operand0_tensor_type.getElementType()
489                               .isa<tf_type::TensorFlowRefType>()) {
490       return switchn.emitOpError()
491              << "expects same operand and output element type but got "
492              << operand0_tensor_type << " vs " << output_tensor_type;
493     }
494     Type broadcasted_type = OpTrait::util::getBroadcastedType(
495         tf_type::DropRefAndSubTypes(operand0_tensor_type),
496         tf_type::DropRefAndSubTypes(output_tensor_type));
497     if (!broadcasted_type) {
498       return switchn.emitOpError()
499              << "expects data operand to be broadcastable with all output types"
500              << " but got " << operand0_tensor_type << " vs "
501              << output_tensor_type;
502     }
503   }
504   return success();
505 }
506 
print(OpAsmPrinter & p)507 void SwitchNOp::print(OpAsmPrinter &p) {
508   p << ' ';
509   auto operands = getOperands();
510   // Print the 2 data operands.
511   p.printOperands(operands.begin(), std::next(operands.begin(), 2));
512   p << " of " << (getNumResults() - 1);
513   // print control dependencies if any
514   if (!llvm::empty(controlInputs())) {
515     p << " (";
516     p.printOperands(controlInputs());
517     p << ")";
518   }
519   p << " : " << getType(0);
520   p.printOptionalAttrDict(getOperation()->getAttrs(), {"num_outs"});
521 }
522 
parse(OpAsmParser & parser,OperationState & result)523 ParseResult SwitchNOp::parse(OpAsmParser &parser, OperationState &result) {
524   // Parsing:
525   //       %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
526   // Where the first operand is the data to replicate, the second is an i32
527   // indicating which output to populate, followed by the keyword `of` and the
528   // number of outputs (+1 for the control token).
529   SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
530   SmallVector<Type, 1> types;
531   llvm::SMLoc loc = parser.getCurrentLocation();
532   IntegerAttr num_outs;
533   Type i64_type = parser.getBuilder().getIntegerType(64);
534   if (parser.parseOperandList(op_infos, 2) || parser.parseKeyword("of") ||
535       parser.parseAttribute(num_outs, i64_type, "num_outs",
536                             result.attributes) ||
537       parser.parseOperandList(op_infos,
538                               OpAsmParser::Delimiter::OptionalParen) ||
539       parser.parseColonTypeList(types))
540     return failure();
541   if (types.size() != 1)
542     return parser.emitError(parser.getNameLoc())
543            << " expects only a single data type";
544 
545   if (num_outs.getInt() <= 0)
546     return parser.emitError(parser.getNameLoc())
547            << " expects a positive number of outputs";
548 
549   // `types` already contains the type for the data, add an i32 for the
550   // output_index, and then the optional control inputs.
551   auto builder = parser.getBuilder();
552   types.push_back(RankedTensorType::get({}, builder.getIntegerType(32)));
553   Type control_type = ControlType::get(builder.getContext());
554   types.append(op_infos.size() - 2, control_type);
555 
556   if (parser.resolveOperands(op_infos, types, loc, result.operands))
557     return failure();
558 
559   // Output result types is a replication `num_outs` times the data input type.
560   result.types.append(num_outs.getInt(), types[0]);
561   result.types.push_back(control_type);
562 
563   return parser.parseOptionalAttrDict(result.attributes);
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // tf_executor.Merge
568 //===----------------------------------------------------------------------===//
569 
verify()570 LogicalResult MergeOp::verify() {
571   MergeOp merge = *this;
572   if (!merge.getNumOperands())
573     return merge.emitOpError() << "expects at least one operand";
574 
575   Type data_type = merge.getOperand(0).getType();
576   if (data_type.isa<ControlType>())
577     return merge.emitOpError() << "expects a non-control input";
578 
579   // Check that each operand can be individually broadcasted to the output type.
580   Type output_type = merge.output().getType();
581   TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
582   if (!output_tensor_ty) {
583     return merge.emitOpError()
584            << "expects output to have tensor type but got " << output_type;
585   }
586   bool is_output_ref =
587       output_tensor_ty.getElementType().isa<tf_type::TensorFlowRefType>();
588   for (Type operand_type : merge.getOperandTypes()) {
589     if (operand_type.isa<ControlType>()) break;
590 
591     // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
592     // constraint.
593     TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>();
594     if (!operand_tensor_ty)
595       return merge.emitOpError()
596              << "expects data operands to have tensor type but got "
597              << operand_type;
598 
599     // If output type is a ref type then all operand types should also be of the
600     // same ref type. However, if the output type is a non-ref type T, operands
601     // can be tensor of type T or T_REF.
602     if (is_output_ref &&
603         !operand_tensor_ty.getElementType().isa<tf_type::TensorFlowRefType>()) {
604       return merge.emitOpError()
605              << "expects same operand and output element type but got "
606              << operand_tensor_ty << " vs " << output_tensor_ty;
607     }
608     Type broadcasted_type = OpTrait::util::getBroadcastedType(
609         tf_type::DropRefAndSubTypes(output_tensor_ty),
610         tf_type::DropRefAndSubTypes(operand_tensor_ty));
611     if (!broadcasted_type)
612       return merge.emitOpError()
613              << "expects all operands to be broadcastable with output type"
614              << " but got " << operand_tensor_ty << " vs " << output_tensor_ty;
615   }
616   return success();
617 }
618 
print(OpAsmPrinter & p)619 void MergeOp::print(OpAsmPrinter &p) {
620   // Use short form only when there are exactly two data operands and their
621   // type matches the output type. Otherwise, use the generic printer.
622   bool use_short_form = true;
623   int num_data_operands = 0;
624 
625   Type output_type = output().getType();
626   for (Type operand_type : getOperandTypes()) {
627     if (operand_type.isa<ControlType>()) break;
628     num_data_operands++;
629 
630     if (operand_type != output_type) {
631       use_short_form = false;
632       break;
633     }
634   }
635 
636   p << ' ';
637   p.printOperands(getOperands());
638 
639   // Print the type signature of the operation.
640   p << " : ";
641   if (!use_short_form || num_data_operands != 2) {
642     p.printFunctionalType(getOperation());
643   } else {
644     p << output_type;
645   }
646 
647   p.printOptionalAttrDict(getOperation()->getAttrs());
648 }
649 
parse(OpAsmParser & parser,OperationState & result)650 ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
651   SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
652   SmallVector<Type, 1> types;
653   llvm::SMLoc loc = parser.getCurrentLocation();
654   if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
655     return failure();
656   if (types.size() != 1)
657     return parser.emitError(parser.getNameLoc())
658            << " expects only a single data type";
659 
660   // Support parsing either a functional type (in which case all the types are
661   // fully qualified) or a short form with a single type (in which case the data
662   // inputs and the output are all using this type).
663   if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
664     result.types.assign(type.getResults().begin(), type.getResults().end());
665     types.assign(type.getInputs().begin(), type.getInputs().end());
666   } else {
667     // In case of the short form, use the parsed type for both the operands and
668     // the remaining operands are expected to be control inputs.
669     types.push_back(Type(types.front()));
670     Type control_type = ControlType::get(parser.getBuilder().getContext());
671     types.append(op_infos.size() - 2, control_type);
672 
673     RankedTensorType i32_tensor =
674         RankedTensorType::get({}, parser.getBuilder().getIntegerType(32));
675     result.types = {types.front(), i32_tensor, control_type};
676   }
677 
678   if (parser.resolveOperands(op_infos, types, loc, result.operands))
679     return failure();
680 
681   return parser.parseOptionalAttrDict(result.attributes);
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // tf_executor.Enter
686 //===----------------------------------------------------------------------===//
687 
// Value of the `parallel_iterations` attribute that the Enter custom syntax
// leaves implicit. EnterOp::print omits the clause when the attribute has
// this value, and EnterOp::parse supplies it when the clause is absent.
static constexpr int kDefaultParallelIterations{10};
690 
print(OpAsmPrinter & p)691 void EnterOp::print(OpAsmPrinter &p) {
692   p << ' ';
693   p.printOperands(getOperands());
694 
695   p << " frame \"";
696   printEscapedString(frame_name(), p.getStream());
697   p << "\"";
698   if (parallel_iterations() != kDefaultParallelIterations)
699     p << " parallel_iterations " << parallel_iterations();
700   if (is_constant()) p << " constant ";
701 
702   // If the types aren't perfectly matching, print the functional type syntax
703   // else print the shorter single type.
704   p << " : ";
705   if (data().getType() != output().getType()) {
706     p.printFunctionalType(getOperation());
707   } else {
708     p << getType(0);
709   }
710 
711   p.printOptionalAttrDict(getOperation()->getAttrs(),
712                           {"frame_name", "parallel_iterations", "is_constant"});
713 }
714 
parse(OpAsmParser & parser,OperationState & result)715 ParseResult EnterOp::parse(OpAsmParser &parser, OperationState &result) {
716   SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
717   llvm::SMLoc loc = parser.getCurrentLocation();
718   MLIRContext *context = parser.getBuilder().getContext();
719   if (parser.parseOperandList(op_infos)) return failure();
720   if (op_infos.empty())
721     return parser.emitError(loc) << " expects at least one data operand";
722 
723   Attribute frame;
724   if (parser.parseKeyword("frame") ||
725       parser.parseAttribute(frame, NoneType::get(context), "frame_name",
726                             result.attributes))
727     return failure();
728 
729   Type i64 = parser.getBuilder().getIntegerType(64);
730   if (parser.parseOptionalKeyword("parallel_iterations")) {
731     result.addAttribute("parallel_iterations",
732                         IntegerAttr::get(i64, kDefaultParallelIterations));
733   } else {
734     IntegerAttr parallel_iterations;
735     if (parser.parseAttribute(parallel_iterations, i64, "parallel_iterations",
736                               result.attributes))
737       return failure();
738   }
739   bool has_constant = succeeded(parser.parseOptionalKeyword("constant"));
740   result.addAttribute("is_constant", BoolAttr::get(context, has_constant));
741 
742   SmallVector<Type, 1> types;
743   if (parser.parseColonTypeList(types)) return failure();
744   if (types.size() != 1)
745     return parser.emitError(loc) << " expects only a single data type";
746 
747   // Support parsing either a functional type (in which case all the types are
748   // fully qualified) or a short form with a single type (in which case the data
749   // input and the outputs are all using this type).
750   if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
751     // One data input, and any number of control inputs.
752     if (type.getNumInputs() >= 1) {
753       result.types.assign(type.getResults().begin(), type.getResults().end());
754       types.assign(type.getInputs().begin(), type.getInputs().end());
755     } else {
756       return parser.emitError(parser.getNameLoc()) << " expects a data input";
757     }
758   } else {
759     Type control_type = ControlType::get(context);
760     types.append(op_infos.size() - 1, control_type);
761     result.addTypes({types.front(), control_type});
762   }
763 
764   // Extra operands are expected to be control inputs.
765 
766   if (parser.resolveOperands(op_infos, types, loc, result.operands))
767     return failure();
768 
769   return parser.parseOptionalAttrDict(result.attributes);
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // tf_executor.NextIteration.Source
774 //===----------------------------------------------------------------------===//
775 
verify()776 LogicalResult NextIterationSourceOp::verify() {
777   NextIterationSourceOp source = *this;
778   Value token = source.token();
779   if (!token.hasOneUse())
780     return source.emitOpError() << "expects a single user for produced token";
781   if (!isa<NextIterationSinkOp>(*token.user_begin()))
782     return source.emitOpError() << "token should be consumed by a sink op";
783   return success();
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // tf_executor.NextIteration.Sink
788 //===----------------------------------------------------------------------===//
789 
verify()790 LogicalResult NextIterationSinkOp::verify() {
791   NextIterationSinkOp sink = *this;
792   Value token = sink.token();
793   Operation *definingOp = token.getDefiningOp();
794   if (!definingOp)
795     return sink.emitOpError() << "expects a token directly produced by a "
796                                  "tf_executor.NextIteration.Source op: ";
797   auto source = dyn_cast<NextIterationSourceOp>(definingOp);
798   if (!source)
799     return sink.emitOpError() << "expects a token produced by a "
800                                  "tf_executor.NextIteration.Source op: ";
801   if (source.output().getType() != sink.input().getType())
802     return sink.emitOpError()
803            << "input type " << sink.input().getType()
804            << " mismatch the tf_executor.NextIteration.Source output type: "
805            << source.output().getType();
806   return success();
807 }
808 
GetSource()809 NextIterationSourceOp NextIterationSinkOp::GetSource() {
810   return cast<NextIterationSourceOp>(token().getDefiningOp());
811 }
812 
813 //===----------------------------------------------------------------------===//
814 // tf_executor.Exit
815 //===----------------------------------------------------------------------===//
816 
print(OpAsmPrinter & p)817 void ExitOp::print(OpAsmPrinter &p) {
818   p << ' ';
819   p.printOperands(getOperands());
820   p << " : " << getType(0);
821   p.printOptionalAttrDict(getOperation()->getAttrs());
822 }
823 
parse(OpAsmParser & parser,OperationState & result)824 ParseResult ExitOp::parse(OpAsmParser &parser, OperationState &result) {
825   SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
826   SmallVector<Type, 1> types;
827 
828   if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
829     return failure();
830 
831   llvm::SMLoc loc = parser.getCurrentLocation();
832   Type control_type = ControlType::get(parser.getBuilder().getContext());
833   types.append(op_infos.size() - 1, control_type);
834   if (parser.resolveOperands(op_infos, types, loc, result.operands))
835     return failure();
836 
837   result.addTypes({types.front(), control_type});
838   return parser.parseOptionalAttrDict(result.attributes);
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // tf_executor.ControlTrigger
843 //===----------------------------------------------------------------------===//
844 
845 //===----------------------------------------------------------------------===//
846 // tf_executor.LoopCond
847 //===----------------------------------------------------------------------===//
848 
print(OpAsmPrinter & p)849 void LoopCondOp::print(OpAsmPrinter &p) {
850   p << ' ';
851   p.printOperands(getOperands());
852 
853   // If the types aren't matching (broadcast), print the functional type syntax.
854   if (input().getType() != output().getType()) {
855     p << " : ";
856     p.printFunctionalType(getOperation());
857   } else {
858     p << " : " << input().getType();
859   }
860 
861   p.printOptionalAttrDict(getOperation()->getAttrs());
862 }
863 
parse(OpAsmParser & parser,OperationState & result)864 ParseResult LoopCondOp::parse(OpAsmParser &parser, OperationState &result) {
865   SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
866 
867   if (parser.parseOperandList(op_infos)) return failure();
868   if (op_infos.empty())
869     return parser.emitError(parser.getNameLoc())
870            << "expects at least one operand";
871 
872   SmallVector<Type, 1> types;
873   if (parser.parseColonTypeList(types)) return failure();
874 
875   // Support parsing either a functional type (in which case all the types are
876   // fully qualified) or a short form with a single type (in which case the data
877   // input and the outputs are all using this type).
878   Type control_type = ControlType::get(parser.getBuilder().getContext());
879   if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
880     if (llvm::count_if(type.getInputs(),
881                        [=](Type type) { return type != control_type; }) != 1)
882       return parser.emitError(parser.getNameLoc())
883              << " expects a single data type";
884     result.types.assign(type.getResults().begin(), type.getResults().end());
885     types.assign(type.getInputs().begin(), type.getInputs().end());
886   } else {
887     if (types.size() != 1)
888       return parser.emitError(parser.getNameLoc())
889              << " expects a single data type";
890     types.append(op_infos.size() - 1, control_type);
891     result.addTypes({types.front(), control_type});
892   }
893 
894   llvm::SMLoc loc = parser.getCurrentLocation();
895   if (parser.resolveOperands(op_infos, types, loc, result.operands))
896     return failure();
897 
898   return parser.parseOptionalAttrDict(result.attributes);
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // Canonicalization patterns
903 //===----------------------------------------------------------------------===//
904 
905 // TODO(lyandy): Add canonicalization for dedupping control inputs.
906 
907 //===----------------------------------------------------------------------===//
908 // tf_executor.graph
909 //===----------------------------------------------------------------------===//
910 
911 namespace {
912 // Finds in a block if the op of type `InnerOpT` is the first operation and
913 // optionally followed by a terminator.
914 template <typename InnerOpT>
HasSingleOpInBlock(Block * block)915 bool HasSingleOpInBlock(Block *block) {
916   if (block->empty()) return false;
917   if (!llvm::isa<InnerOpT>(block->front())) return false;
918   // Either InnerOpT is the only instruction in the block, or there is a
919   // possible terminator.
920   return std::next(block->begin()) == block->end() ||
921          std::next(block->begin(), 2) == block->end();
922 }
923 
924 // This pattern matches GraphOps with only one FetchOp (empty) and remaps the
925 // results of the GraphOp to the operands of the FetchOp.
926 struct DropEmptyGraph : public OpRewritePattern<GraphOp> {
927   using OpRewritePattern<GraphOp>::OpRewritePattern;
928 
matchAndRewritemlir::tf_executor::__anon2366cd000411::DropEmptyGraph929   LogicalResult matchAndRewrite(GraphOp op,
930                                 PatternRewriter &rewriter) const override {
931     Block &block = op.GetBody();
932     // Check if graph only has one fetch.
933     if (&block.front() != &block.back()) return failure();
934 
935     // Map graph results to fetch operands.
936     rewriter.replaceOp(op, op.GetFetch().fetches());
937 
938     return success();
939   }
940 };
941 
942 // This pattern matches GraphOps with only one island, pulls out all inner ops
943 // of the island to the block containing the GraphOp, and then removes the
944 // GraphOp.
945 struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> {
946   using OpRewritePattern<GraphOp>::OpRewritePattern;
947 
matchAndRewritemlir::tf_executor::__anon2366cd000411::HoistInnerOpsSingleIslandGraph948   LogicalResult matchAndRewrite(GraphOp op,
949                                 PatternRewriter &rewriter) const override {
950     Block &block = op.GetBody();
951     // Check if graph only has one island.
952     if (!HasSingleOpInBlock<IslandOp>(&block)) return failure();
953 
954     FetchOp fetch_op = op.GetFetch();
955     auto island_op = llvm::cast<IslandOp>(block.front());
956     YieldOp yield_op = island_op.GetYield();
957 
958     // Map graph results to inner ops results of single island.
959     llvm::SmallVector<Value, 8> new_rets;
960     for (Value operand : fetch_op.fetches()) {
961       // Control results should not be propagated out.
962       if (operand.getType().isa<ControlType>()) break;
963 
964       if (operand.getDefiningOp() != island_op) {
965         // Operand is not from island, simply propagate it out.
966         new_rets.push_back(operand);
967       } else {
968         // Lookup yield operand in island for inner op result.
969         auto result = operand.cast<OpResult>();
970         new_rets.push_back(yield_op.getOperand(result.getResultNumber()));
971       }
972     }
973 
974     // Move inner ops from island to block containing graph.
975     auto &island_body = island_op.GetBody().getOperations();
976     Operation *operation = op.getOperation();
977     operation->getBlock()->getOperations().splice(
978         operation->getIterator(), island_body, island_body.begin(),
979         std::prev(island_body.end()));
980     rewriter.replaceOp(op, new_rets);
981 
982     return success();
983   }
984 };
985 }  // anonymous namespace
986 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)987 void GraphOp::getCanonicalizationPatterns(RewritePatternSet &results,
988                                           MLIRContext *context) {
989   results.add<DropEmptyGraph, HoistInnerOpsSingleIslandGraph>(context);
990 }
991 
992 //===----------------------------------------------------------------------===//
993 // tf_executor.island
994 //===----------------------------------------------------------------------===//
995 
996 namespace {
997 // This pattern matches and removes IslandOps with no inner ops, no control
998 // operands and no data results. Control result users will have their relevant
999 // operands removed.
1000 struct DropEmptyIslandNoOperandNoDataResult
1001     : public OpRewritePattern<IslandOp> {
1002   using OpRewritePattern<IslandOp>::OpRewritePattern;
1003 
matchAndRewritemlir::tf_executor::__anon2366cd000511::DropEmptyIslandNoOperandNoDataResult1004   LogicalResult matchAndRewrite(IslandOp op,
1005                                 PatternRewriter &rewriter) const override {
1006     if (op.getNumOperands() != 0 || op.getNumResults() != 1 ||
1007         !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
1008       return failure();
1009 
1010     for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
1011       use.getOwner()->eraseOperand(use.getOperandNumber());
1012 
1013     rewriter.eraseOp(op);
1014 
1015     return success();
1016   }
1017 };
1018 
1019 // This pattern matches and removes IslandOps with no inner ops, no control
1020 // operands, one data result and no control result user. The single data result
1021 // (from YieldOps first operand) is forwarded to the IslandOp single data result
1022 // users.
1023 struct DropEmptyIslandNoOperandOneDataResult
1024     : public OpRewritePattern<IslandOp> {
1025   using OpRewritePattern<IslandOp>::OpRewritePattern;
1026 
matchAndRewritemlir::tf_executor::__anon2366cd000511::DropEmptyIslandNoOperandOneDataResult1027   LogicalResult matchAndRewrite(IslandOp op,
1028                                 PatternRewriter &rewriter) const override {
1029     if (op.getNumOperands() != 0 || op.getNumResults() != 2 ||
1030         !op.control().use_empty() ||
1031         !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
1032       return failure();
1033 
1034     rewriter.replaceOp(op, {op.GetYield().getOperand(0), nullptr});
1035 
1036     return success();
1037   }
1038 };
1039 
1040 // TODO(lyandy): Add canonicalization for empty IslandOps with more than one
1041 // control operand and no data results.
1042 
1043 }  // anonymous namespace
1044 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1045 void IslandOp::getCanonicalizationPatterns(RewritePatternSet &results,
1046                                            MLIRContext *context) {
1047   results.add<DropEmptyIslandNoOperandNoDataResult,
1048               DropEmptyIslandNoOperandOneDataResult>(context);
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // tf_executor.ControlTrigger
1053 //===----------------------------------------------------------------------===//
1054 
1055 namespace {
1056 // This pattern matches and removes ControlTriggerOps with no control operands.
1057 // Control result users will have their relevant operands removed.
1058 struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
1059   using OpRewritePattern<ControlTriggerOp>::OpRewritePattern;
1060 
matchAndRewritemlir::tf_executor::__anon2366cd000611::DropEmptyControlTrigger1061   LogicalResult matchAndRewrite(ControlTriggerOp op,
1062                                 PatternRewriter &rewriter) const override {
1063     if (op.getNumOperands() != 0) return failure();
1064 
1065     for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
1066       use.getOwner()->eraseOperand(use.getOperandNumber());
1067 
1068     rewriter.eraseOp(op);
1069 
1070     return success();
1071   }
1072 };
1073 }  // anonymous namespace
1074 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1075 void ControlTriggerOp::getCanonicalizationPatterns(RewritePatternSet &results,
1076                                                    MLIRContext *context) {
1077   results.add<DropEmptyControlTrigger>(context);
1078 }
1079 
1080 //===----------------------------------------------------------------------===//
1081 // Folders
1082 //===----------------------------------------------------------------------===//
1083 
1084 //===----------------------------------------------------------------------===//
1085 // tf_executor.island
1086 //===----------------------------------------------------------------------===//
1087 
fold(llvm::ArrayRef<Attribute> operands,llvm::SmallVectorImpl<OpFoldResult> & results)1088 LogicalResult IslandOp::fold(llvm::ArrayRef<Attribute> operands,
1089                              llvm::SmallVectorImpl<OpFoldResult> &results) {
1090   // This folds IslandOps with no inner ops, one control operand and no data
1091   // results. The single control operand is forwarded to the IslandOp control
1092   // result users.
1093   if (getNumOperands() != 1 || getNumResults() != 1 ||
1094       !HasSingleOpInBlock<YieldOp>(&GetBody()))
1095     return failure();
1096 
1097   results.emplace_back(getOperand(0));
1098 
1099   return success();
1100 }
1101 
1102 }  // namespace tf_executor
1103 }  // namespace mlir
1104 
1105 //===----------------------------------------------------------------------===//
1106 // TableGen'd op method definitions
1107 //===----------------------------------------------------------------------===//
1108 
1109 #define GET_OP_CLASSES
1110 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
1111