/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"

#include <algorithm>
#include <iterator>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace tf_executor {

//===----------------------------------------------------------------------===//
// TF Executor Dialect
//===----------------------------------------------------------------------===//

namespace {

struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  // Override the inlining hook to determine if 'src' can be inlined into
  // 'dest'.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &value_mapping) const final {
    // Allow inlining into tf_executor.island regions if the incoming region
    // has a single block.
    return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
           llvm::hasSingleElement(*src);
  }
};

struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  // Registered hook to check if the given region, which is attached to an
  // operation that is *not* isolated from above (i.e. whose internal regions
  // may reference values defined in an enclosing region), should be used when
  // materializing constants.
  // In the executor dialect we materialize inside an island.
  bool shouldMaterializeInto(Region *region) const final {
    return isa<tf_executor::IslandOp>(region->getParentOp());
  }
};

}  // namespace

TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf_executor", context,
              TypeID::get<TensorFlowExecutorDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
      >();

  addInterfaces<TensorFlowExecutorInlinerInterface,
                TensorFlowExecutorDialectFoldInterface>();

  addTypes<ControlType, TokenType>();
}

Type TensorFlowExecutorDialect::parseType(DialectAsmParser &parser) const {
  StringRef data_type;
  if (parser.parseKeyword(&data_type)) return Type();

  if (data_type == "control") return ControlType::get(getContext());
  if (data_type == "token") return TokenType::get(getContext());
  parser.emitError(parser.getNameLoc())
      << "unknown tf_executor type: " << data_type;
  return nullptr;
}

void TensorFlowExecutorDialect::printType(Type type,
                                          DialectAsmPrinter &os) const {
  if (type.isa<ControlType>()) {
    os << "control";
    return;
  }
  if (type.isa<TokenType>()) {
    os << "token";
    return;
  }
  os << "<unknown tf_executor type>";
}
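
// For illustration (comment only, not compiled): following the standard MLIR
// dialect type syntax, the two types above print with the dialect prefix in
// the textual IR, i.e. `!tf_executor.control` and `!tf_executor.token`.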

//===----------------------------------------------------------------------===//
// Implementation for all the operations defined in ODS (op definition spec).
//===----------------------------------------------------------------------===//

namespace {

// Verifies that all control operands are at the end of the operand list.
// Used by the constraint `ControlOperandsAfterAllData` in ODS.
LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
  bool found_control = false;
  for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
    if (op->getOperand(operand_idx).getType().isa<ControlType>()) {
      found_control = true;
      continue;
    }
    if (found_control)
      return op->emitOpError() << "found non-control operand #" << operand_idx
                               << " after control operand";
  }
  return success();
}

}  // anonymous namespace

//===----------------------------------------------------------------------===//
// tf_executor.graph
//===----------------------------------------------------------------------===//

FetchOp GraphOp::GetFetch() { return llvm::cast<FetchOp>(GetBody().back()); }

LogicalResult GraphOp::verify() {
  GraphOp graph = *this;
  auto *executorDialect = graph->getDialect();

  if (graph.GetBody().empty())
    return graph.emitOpError() << "expects a non-empty body";

  // Only tf_executor dialect operations are allowed to be immediately nested
  // in a tf_executor.graph region.
  for (Operation &op : graph.GetBody()) {
    if (op.getDialect() != executorDialect)
      return op.emitOpError() << "unallowed inside a tf_executor.graph region";
    if (isa<GraphOp>(op))
      return op.emitOpError()
             << "unallowed directly inside another tf_executor.graph";
  }

  Operation &fetch = graph.GetBody().back();
  if (!isa<FetchOp>(fetch))
    return fetch.emitOpError()
           << "invalid tf_executor.graph terminator, fetch expected";

  // Ensure that the fetch terminator operands match the graph result types.
  // All the non-control operands of the fetch operation must match the graph
  // returned values.
  if (fetch.getNumOperands() < graph.getNumResults())
    return fetch.emitOpError() << "does not have enough operands to cover the "
                                  "graph returned values";
  for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
    Value operand = fetch.getOperand(i);
    // Break out of the loop at the first control operand encountered.
    const int64_t num_results = graph.getNumResults();
    if (operand.getType().isa<ControlType>()) {
      if (i != num_results)
        return fetch.emitOpError()
               << "operand #" << i
               << " is a control type, can't be bound to a graph result";
      break;
    }
    if (i >= num_results)
      return fetch.emitOpError()
             << "operand #" << i << " does not have a graph result to bind";
    if (graph.getResult(i).getType() != operand.getType()) {
      return fetch.emitOpError()
             << "operand #" << i << " type mismatch graph results ("
             << graph.getResult(i).getType() << " != " << operand.getType()
             << ")";
    }
  }
  return success();
}

void GraphOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printRegion(getOperation()->getRegion(0));
  p.printOptionalAttrDict(getOperation()->getAttrs());
}

ParseResult GraphOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc loc = parser.getCurrentLocation();

  // Parse the body region.
  Region &body = *result.addRegion();
  if (parser.parseRegion(body)) return failure();

  // Ensure that the region is well formed: it contains at least a block with
  // a FetchOp terminator.
  GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);

  if (!llvm::hasSingleElement(body))
    return parser.emitError(loc) << "expects a single block region";

  // Get the result types from the terminator inside the graph.
  Operation &fetch = body.back().back();
  if (!isa<FetchOp>(fetch))
    return parser.emitError(loc) << "expects a tf_executor.fetch terminator";

  // The return values of the graph operation are the non-control operands of
  // the fetch operation.
  result.types.reserve(fetch.getNumOperands());
  for (Type type : fetch.getOperandTypes()) {
    if (type.isa<ControlType>()) break;
    result.types.push_back(type);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  return success();
}
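
// For illustration (comment only, not compiled): a well-formed graph as
// accepted by the parser above looks roughly like
//
//   %res = tf_executor.graph {
//     %val, %ctl = tf_executor.island {
//       %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
//       tf_executor.yield %0 : tensor<i32>
//     }
//     tf_executor.fetch %val : tensor<i32>
//   }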

//===----------------------------------------------------------------------===//
// tf_executor.fetch
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//

YieldOp IslandOp::GetYield() { return llvm::cast<YieldOp>(GetBody().back()); }

// Checks if a tf_executor.island wraps a single operation and the results of
// that single operation are perfectly forwarded to the island's yield.
bool IslandOp::WrapsSingleOp() {
  auto body = GetBody().without_terminator();
  if (!hasSingleElement(body)) return false;

  Operation &wrapped_op = *body.begin();
  YieldOp yield = GetYield();
  return wrapped_op.getNumResults() == yield.getNumOperands() &&
         std::equal(wrapped_op.getResults().begin(),
                    wrapped_op.getResults().end(), yield.getOperands().begin());
}

mlir::LogicalResult IslandOp::verify() {
  IslandOp island = *this;
  if (!island.GetBody().args_empty())
    return island.emitOpError() << "expects body without any arguments";

  Operation &yield = island.GetBody().back();
  if (!isa<YieldOp>(yield))
    return yield.emitOpError()
           << "invalid tf_executor.island terminator, yield expected";

  // Ensure that the yield terminator operands match the island result types.
  int result_count = island.getNumResults() - 1;  // -1 for the control token
  const int num_operands = yield.getNumOperands();
  if (num_operands != result_count)
    return yield.emitOpError()
           << "has " << yield.getNumOperands()
           << " operands, but island returns " << result_count;
  for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
    if (island.getResult(operand_idx).getType() !=
        yield.getOperand(operand_idx).getType())
      return yield.emitOpError()
             << "operand #" << operand_idx << " type mismatch island results";
  }

  // Check that there aren't any control results other than the last one.
  Type control_type = ControlType::get(island.getContext());
  for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
    if (island.getResult(operand_idx).getType() == control_type)
      return yield.emitOpError()
             << "unexpected control type for operand #" << operand_idx;
  }
  return success();
}

void IslandOp::print(OpAsmPrinter &p) {
  if (getNumOperands()) {
    // These are always control operands, no explicit type needed.
    p << '(';
    p.printOperands(getOperands());
    p << ')';
  }

  // Check if we can print the short "wraps" form: that is, if the island
  // contains a single operation and the results of this operation are
  // perfectly forwarded to the yield.
  if (getOperation()->getAttrs().empty() && WrapsSingleOp()) {
    Operation &wrapped_op = GetBody().front();
    YieldOp yield_op = GetYield();
    // The "wraps" syntax only encodes a single location.
    // In order to correctly round-trip, we can only use this syntax when all
    // the locations are identical.
    if (wrapped_op.getLoc() == getLoc() && yield_op.getLoc() == getLoc()) {
      p << " wraps ";
      p.printGenericOp(&wrapped_op);
      return;
    }
  }
  p << ' ';
  p.printRegion(getOperation()->getRegion(0));
  p.printOptionalAttrDict(getOperation()->getAttrs());
}

ParseResult IslandOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Type control_type = ControlType::get(parser.getBuilder().getContext());

  // Parse the optional argument list (control dependencies only).
  SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos;
  if (parser.parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
    return failure();
  if (!op_infos.empty()) {
    SmallVector<Type, 2> types(op_infos.size(), control_type);
    if (parser.resolveOperands(op_infos, types, loc, result.operands))
      return failure();
  }

  // Parse the body region.
  Region &body = *result.addRegion();

  if (succeeded(parser.parseOptionalKeyword("wraps"))) {
    // If we parse the short version of the island, we have an operation in
    // the generic form that follows the "wraps" keyword. Parse it inside the
    // region and forward all of its results as-is to the yield operation.
    body.push_back(new Block);
    Block &block = body.back();
    Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
    if (!wrapped_op) return failure();
    OpBuilder builder(parser.getBuilder().getContext());
    builder.setInsertionPointToEnd(&block);
    builder.create<YieldOp>(wrapped_op->getLoc(), wrapped_op->getResults());
    result.location = wrapped_op->getLoc();
  } else if (parser.parseRegion(body)) {
    return failure();
  }

  IslandOp::ensureTerminator(body, parser.getBuilder(), result.location);

  // Get the result types for the island from the terminator operands.
  Operation &yield = body.back().back();
  result.types.reserve(yield.getNumOperands() + 1);
  result.types.append(yield.operand_type_begin(), yield.operand_type_end());
  result.types.push_back(control_type);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
  return success();
}
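
// For illustration (comment only, not compiled): the short "wraps" form and
// its expanded equivalent look roughly like
//
//   %val, %ctl = tf_executor.island wraps "tf.Identity"(%arg)
//       : (tensor<f32>) -> tensor<f32>
//
//   %val, %ctl = tf_executor.island {
//     %0 = "tf.Identity"(%arg) : (tensor<f32>) -> tensor<f32>
//     tf_executor.yield %0 : tensor<f32>
//   }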

//===----------------------------------------------------------------------===//
// tf_executor.yield
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.Switch
//===----------------------------------------------------------------------===//

ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
  SmallVector<Type, 1> types;
  if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser.emitError(parser.getNameLoc())
           << " expects only a single data type";

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data input and the outputs all use this type, and the predicate is a
  // tensor<i1>).
  if (types.front().isa<FunctionType>()) {
    FunctionType type = types.front().cast<FunctionType>();
    if (type.getNumInputs() < 2)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type and a predicate";
    result.types.assign(type.getResults().begin(), type.getResults().end());
    types.assign(type.getInputs().begin(), type.getInputs().end());
  } else {
    if (op_infos.size() < 2)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type and a predicate";
    Type control_type = ControlType::get(parser.getBuilder().getContext());
    result.types.append(2, types[0]);
    result.types.push_back(control_type);
    Type i1_type = parser.getBuilder().getI1Type();
    RankedTensorType predicate_type = RankedTensorType::get({}, i1_type);
    types.push_back(predicate_type);
    types.append(op_infos.size() - 2, control_type);
  }

  llvm::SMLoc loc = parser.getCurrentLocation();
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}

void SwitchOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printOperands(getOperands());
  Type data_operand_ty = data().getType();
  // If the types aren't perfectly matching, print the functional type syntax;
  // otherwise print the shorter single type.
  p << " : ";
  if (trueOutput().getType() != data_operand_ty ||
      falseOutput().getType() != data_operand_ty ||
      predicate().getType().isa<UnrankedTensorType>()) {
    p.printFunctionalType(getOperation());
  } else {
    p << getType(0);
  }
  p.printOptionalAttrDict(getOperation()->getAttrs());
}
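
// For illustration (comment only, not compiled): the short Switch form looks
// roughly like
//
//   %true, %false, %ctl = tf_executor.Switch %data, %pred : tensor<i32>
//
// where %pred is a tensor<i1> predicate and both data outputs share the type
// of the data input.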

//===----------------------------------------------------------------------===//
// tf_executor.SwitchN
//===----------------------------------------------------------------------===//

LogicalResult SwitchNOp::verify() {
  SwitchNOp switchn = *this;
  IntegerAttr num_outs = switchn->getAttrOfType<IntegerAttr>("num_outs");
  if (!num_outs)
    return switchn.emitOpError() << "expects a `num_outs` integer attribute";

  // Expect num_outs results + 1 control output.
  if (switchn.getNumResults() != num_outs.getInt() + 1)
    return switchn.emitOpError()
           << "expect `num_outs` (" << num_outs.getInt() << ") results but got "
           << (switchn.getNumResults() - 1);

  // Check that the operand can be broadcast to each output type.
  auto operand0_type = switchn.getOperand(0).getType();
  TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
  if (!operand0_tensor_type) {
    return switchn.emitOpError()
           << "expects data operand to have tensor type but got "
           << operand0_type;
  }
  for (Type output_type : switchn.getResultTypes()) {
    if (output_type.isa<ControlType>()) break;

    TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
    if (!output_tensor_type) {
      return switchn.emitOpError()
             << "expects outputs to have tensor type but got " << output_type;
    }

    // If the output type is a ref type, then the operand type should also be
    // of the same ref type. However, if the output type is a non-ref type T,
    // then the operand can be a tensor of type T or T_REF.
    bool is_output_ref =
        output_tensor_type.getElementType().isa<tf_type::TensorFlowRefType>();
    if (is_output_ref && !operand0_tensor_type.getElementType()
                              .isa<tf_type::TensorFlowRefType>()) {
      return switchn.emitOpError()
             << "expects same operand and output element type but got "
             << operand0_tensor_type << " vs " << output_tensor_type;
    }
    Type broadcasted_type = OpTrait::util::getBroadcastedType(
        tf_type::DropRefAndSubTypes(operand0_tensor_type),
        tf_type::DropRefAndSubTypes(output_tensor_type));
    if (!broadcasted_type) {
      return switchn.emitOpError()
             << "expects data operand to be broadcastable with all output types"
             << " but got " << operand0_tensor_type << " vs "
             << output_tensor_type;
    }
  }
  return success();
}

void SwitchNOp::print(OpAsmPrinter &p) {
  p << ' ';
  auto operands = getOperands();
  // Print the 2 data operands.
  p.printOperands(operands.begin(), std::next(operands.begin(), 2));
  p << " of " << (getNumResults() - 1);
  // Print control dependencies if any.
  if (!llvm::empty(controlInputs())) {
    p << " (";
    p.printOperands(controlInputs());
    p << ")";
  }
  p << " : " << getType(0);
  p.printOptionalAttrDict(getOperation()->getAttrs(), {"num_outs"});
}

ParseResult SwitchNOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parsing:
  //   %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
  // Where the first operand is the data to replicate, the second is an i32
  // indicating which output to populate, followed by the keyword `of` and the
  // number of outputs (+1 for the control token).
  SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
  SmallVector<Type, 1> types;
  llvm::SMLoc loc = parser.getCurrentLocation();
  IntegerAttr num_outs;
  Type i64_type = parser.getBuilder().getIntegerType(64);
  if (parser.parseOperandList(op_infos, 2) || parser.parseKeyword("of") ||
      parser.parseAttribute(num_outs, i64_type, "num_outs",
                            result.attributes) ||
      parser.parseOperandList(op_infos,
                              OpAsmParser::Delimiter::OptionalParen) ||
      parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser.emitError(parser.getNameLoc())
           << " expects only a single data type";

  if (num_outs.getInt() <= 0)
    return parser.emitError(parser.getNameLoc())
           << " expects a positive number of outputs";

  // `types` already contains the type for the data; add an i32 for the
  // output_index, and then the optional control inputs.
  auto builder = parser.getBuilder();
  types.push_back(RankedTensorType::get({}, builder.getIntegerType(32)));
  Type control_type = ControlType::get(builder.getContext());
  types.append(op_infos.size() - 2, control_type);

  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  // The output result types are the data input type replicated `num_outs`
  // times, followed by a control type.
  result.types.append(num_outs.getInt(), types[0]);
  result.types.push_back(control_type);

  return parser.parseOptionalAttrDict(result.attributes);
}

//===----------------------------------------------------------------------===//
// tf_executor.Merge
//===----------------------------------------------------------------------===//

LogicalResult MergeOp::verify() {
  MergeOp merge = *this;
  if (!merge.getNumOperands())
    return merge.emitOpError() << "expects at least one operand";

  Type data_type = merge.getOperand(0).getType();
  if (data_type.isa<ControlType>())
    return merge.emitOpError() << "expects a non-control input";

  // Check that each operand can be individually broadcast to the output type.
  Type output_type = merge.output().getType();
  TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
  if (!output_tensor_ty) {
    return merge.emitOpError()
           << "expects output to have tensor type but got " << output_type;
  }
  bool is_output_ref =
      output_tensor_ty.getElementType().isa<tf_type::TensorFlowRefType>();
  for (Type operand_type : merge.getOperandTypes()) {
    if (operand_type.isa<ControlType>()) break;

    // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this
    // constraint.
    TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>();
    if (!operand_tensor_ty)
      return merge.emitOpError()
             << "expects data operands to have tensor type but got "
             << operand_type;

    // If the output type is a ref type, then all operand types should also be
    // of the same ref type. However, if the output type is a non-ref type T,
    // operands can be tensors of type T or T_REF.
    if (is_output_ref &&
        !operand_tensor_ty.getElementType().isa<tf_type::TensorFlowRefType>()) {
      return merge.emitOpError()
             << "expects same operand and output element type but got "
             << operand_tensor_ty << " vs " << output_tensor_ty;
    }
    Type broadcasted_type = OpTrait::util::getBroadcastedType(
        tf_type::DropRefAndSubTypes(output_tensor_ty),
        tf_type::DropRefAndSubTypes(operand_tensor_ty));
    if (!broadcasted_type)
      return merge.emitOpError()
             << "expects all operands to be broadcastable with output type"
             << " but got " << operand_tensor_ty << " vs " << output_tensor_ty;
  }
  return success();
}

void MergeOp::print(OpAsmPrinter &p) {
  // Use the short form only when there are exactly two data operands and
  // their type matches the output type. Otherwise, use the generic printer.
  bool use_short_form = true;
  int num_data_operands = 0;

  Type output_type = output().getType();
  for (Type operand_type : getOperandTypes()) {
    if (operand_type.isa<ControlType>()) break;
    num_data_operands++;

    if (operand_type != output_type) {
      use_short_form = false;
      break;
    }
  }

  p << ' ';
  p.printOperands(getOperands());

  // Print the type signature of the operation.
  p << " : ";
  if (!use_short_form || num_data_operands != 2) {
    p.printFunctionalType(getOperation());
  } else {
    p << output_type;
  }

  p.printOptionalAttrDict(getOperation()->getAttrs());
}

ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
  SmallVector<Type, 1> types;
  llvm::SMLoc loc = parser.getCurrentLocation();
  if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 1)
    return parser.emitError(parser.getNameLoc())
           << " expects only a single data type";

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data inputs and the output all use this type).
  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
    result.types.assign(type.getResults().begin(), type.getResults().end());
    types.assign(type.getInputs().begin(), type.getInputs().end());
  } else {
    // In the short form, the single parsed type applies to both data
    // operands; any remaining operands are expected to be control inputs.
    types.push_back(Type(types.front()));
    Type control_type = ControlType::get(parser.getBuilder().getContext());
    types.append(op_infos.size() - 2, control_type);

    RankedTensorType i32_tensor =
        RankedTensorType::get({}, parser.getBuilder().getIntegerType(32));
    result.types = {types.front(), i32_tensor, control_type};
  }

  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}
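
// For illustration (comment only, not compiled): the short Merge form looks
// roughly like
//
//   %value, %idx, %ctl = tf_executor.Merge %a, %b : tensor<i32>
//
// where %idx is a tensor<i32> indicating which input was forwarded.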

//===----------------------------------------------------------------------===//
// tf_executor.Enter
//===----------------------------------------------------------------------===//

// Default value for the parallel_iterations attribute on Enter nodes.
static constexpr int kDefaultParallelIterations = 10;

void EnterOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printOperands(getOperands());

  p << " frame \"";
  printEscapedString(frame_name(), p.getStream());
  p << "\"";
  if (parallel_iterations() != kDefaultParallelIterations)
    p << " parallel_iterations " << parallel_iterations();
  if (is_constant()) p << " constant ";

  // If the types aren't perfectly matching, print the functional type syntax;
  // otherwise print the shorter single type.
  p << " : ";
  if (data().getType() != output().getType()) {
    p.printFunctionalType(getOperation());
  } else {
    p << getType(0);
  }

  p.printOptionalAttrDict(getOperation()->getAttrs(),
                          {"frame_name", "parallel_iterations", "is_constant"});
}

ParseResult EnterOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
  llvm::SMLoc loc = parser.getCurrentLocation();
  MLIRContext *context = parser.getBuilder().getContext();
  if (parser.parseOperandList(op_infos)) return failure();
  if (op_infos.empty())
    return parser.emitError(loc) << " expects at least one data operand";

  Attribute frame;
  if (parser.parseKeyword("frame") ||
      parser.parseAttribute(frame, NoneType::get(context), "frame_name",
                            result.attributes))
    return failure();

  Type i64 = parser.getBuilder().getIntegerType(64);
  if (parser.parseOptionalKeyword("parallel_iterations")) {
    result.addAttribute("parallel_iterations",
                        IntegerAttr::get(i64, kDefaultParallelIterations));
  } else {
    IntegerAttr parallel_iterations;
    if (parser.parseAttribute(parallel_iterations, i64, "parallel_iterations",
                              result.attributes))
      return failure();
  }
  bool has_constant = succeeded(parser.parseOptionalKeyword("constant"));
  result.addAttribute("is_constant", BoolAttr::get(context, has_constant));

  SmallVector<Type, 1> types;
  if (parser.parseColonTypeList(types)) return failure();
  if (types.size() != 1)
    return parser.emitError(loc) << " expects only a single data type";

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data input and the outputs all use this type).
  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
    // One data input, and any number of control inputs.
    if (type.getNumInputs() >= 1) {
      result.types.assign(type.getResults().begin(), type.getResults().end());
      types.assign(type.getInputs().begin(), type.getInputs().end());
    } else {
      return parser.emitError(parser.getNameLoc()) << " expects a data input";
    }
  } else {
    Type control_type = ControlType::get(context);
    types.append(op_infos.size() - 1, control_type);
    result.addTypes({types.front(), control_type});
  }

  // Extra operands are expected to be control inputs.
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}
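
// For illustration (comment only, not compiled): an Enter using the default
// parallel_iterations looks roughly like
//
//   %res, %ctl = tf_executor.Enter %arg frame "while/while_context"
//       : tensor<i32>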

//===----------------------------------------------------------------------===//
// tf_executor.NextIteration.Source
//===----------------------------------------------------------------------===//

LogicalResult NextIterationSourceOp::verify() {
  NextIterationSourceOp source = *this;
  Value token = source.token();
  if (!token.hasOneUse())
    return source.emitOpError() << "expects a single user for produced token";
  if (!isa<NextIterationSinkOp>(*token.user_begin()))
    return source.emitOpError() << "token should be consumed by a sink op";
  return success();
}

//===----------------------------------------------------------------------===//
// tf_executor.NextIteration.Sink
//===----------------------------------------------------------------------===//

LogicalResult NextIterationSinkOp::verify() {
  NextIterationSinkOp sink = *this;
  Value token = sink.token();
  Operation *definingOp = token.getDefiningOp();
  if (!definingOp)
    return sink.emitOpError() << "expects a token directly produced by a "
                                 "tf_executor.NextIteration.Source op: ";
  auto source = dyn_cast<NextIterationSourceOp>(definingOp);
  if (!source)
    return sink.emitOpError() << "expects a token produced by a "
                                 "tf_executor.NextIteration.Source op: ";
  if (source.output().getType() != sink.input().getType())
    return sink.emitOpError()
           << "input type " << sink.input().getType()
           << " mismatches the tf_executor.NextIteration.Source output type: "
           << source.output().getType();
  return success();
}

NextIterationSourceOp NextIterationSinkOp::GetSource() {
  return cast<NextIterationSourceOp>(token().getDefiningOp());
}

//===----------------------------------------------------------------------===//
// tf_executor.Exit
//===----------------------------------------------------------------------===//

void ExitOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printOperands(getOperands());
  p << " : " << getType(0);
  p.printOptionalAttrDict(getOperation()->getAttrs());
}

ParseResult ExitOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;
  SmallVector<Type, 1> types;

  if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
    return failure();

  llvm::SMLoc loc = parser.getCurrentLocation();
  Type control_type = ControlType::get(parser.getBuilder().getContext());
  types.append(op_infos.size() - 1, control_type);
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  result.addTypes({types.front(), control_type});
  return parser.parseOptionalAttrDict(result.attributes);
}
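
// For illustration (comment only, not compiled): an Exit looks roughly like
//
//   %val, %ctl = tf_executor.Exit %arg : tensor<i32>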

//===----------------------------------------------------------------------===//
// tf_executor.ControlTrigger
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.LoopCond
//===----------------------------------------------------------------------===//

void LoopCondOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printOperands(getOperands());

  // If the types aren't matching (broadcast), print the functional type
  // syntax.
  if (input().getType() != output().getType()) {
    p << " : ";
    p.printFunctionalType(getOperation());
  } else {
    p << " : " << input().getType();
  }

  p.printOptionalAttrDict(getOperation()->getAttrs());
}

ParseResult LoopCondOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> op_infos;

  if (parser.parseOperandList(op_infos)) return failure();
  if (op_infos.empty())
    return parser.emitError(parser.getNameLoc())
           << "expects at least one operand";

  SmallVector<Type, 1> types;
  if (parser.parseColonTypeList(types)) return failure();

  // Support parsing either a functional type (in which case all the types are
  // fully qualified) or a short form with a single type (in which case the
  // data input and the outputs all use this type).
  Type control_type = ControlType::get(parser.getBuilder().getContext());
  if (FunctionType type = types.front().dyn_cast<FunctionType>()) {
    if (llvm::count_if(type.getInputs(),
                       [=](Type type) { return type != control_type; }) != 1)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type";
    result.types.assign(type.getResults().begin(), type.getResults().end());
    types.assign(type.getInputs().begin(), type.getInputs().end());
  } else {
    if (types.size() != 1)
      return parser.emitError(parser.getNameLoc())
             << " expects a single data type";
    types.append(op_infos.size() - 1, control_type);
    result.addTypes({types.front(), control_type});
  }

  llvm::SMLoc loc = parser.getCurrentLocation();
  if (parser.resolveOperands(op_infos, types, loc, result.operands))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}
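
// For illustration (comment only, not compiled): the short LoopCond form
// looks roughly like
//
//   %cond, %ctl = tf_executor.LoopCond %pred : tensor<i1>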

//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//

// TODO(lyandy): Add canonicalization for dedupping control inputs.

//===----------------------------------------------------------------------===//
// tf_executor.graph
//===----------------------------------------------------------------------===//

namespace {
// Returns true if the first operation in the block is of type `InnerOpT`,
// optionally followed only by a terminator.
template <typename InnerOpT>
bool HasSingleOpInBlock(Block *block) {
  if (block->empty()) return false;
  if (!llvm::isa<InnerOpT>(block->front())) return false;
  // Either InnerOpT is the only operation in the block, or it is followed by
  // a terminator.
  return std::next(block->begin()) == block->end() ||
         std::next(block->begin(), 2) == block->end();
}

// This pattern matches GraphOps whose body contains only a FetchOp (i.e. an
// otherwise empty graph) and remaps the results of the GraphOp to the
// operands of the FetchOp.
struct DropEmptyGraph : public OpRewritePattern<GraphOp> {
  using OpRewritePattern<GraphOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GraphOp op,
                                PatternRewriter &rewriter) const override {
    Block &block = op.GetBody();
    // Check if the graph only has a fetch.
    if (&block.front() != &block.back()) return failure();

    // Map graph results to fetch operands.
    rewriter.replaceOp(op, op.GetFetch().fetches());

    return success();
  }
};
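
// For illustration (comment only, not compiled): DropEmptyGraph rewrites
//
//   %res = tf_executor.graph {
//     tf_executor.fetch %arg : tensor<i32>
//   }
//
// into direct uses of %arg by all users of %res.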

// This pattern matches GraphOps with only one island, pulls out all inner ops
// of the island to the block containing the GraphOp, and then removes the
// GraphOp.
struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> {
  using OpRewritePattern<GraphOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GraphOp op,
                                PatternRewriter &rewriter) const override {
    Block &block = op.GetBody();
    // Check if the graph only has one island.
    if (!HasSingleOpInBlock<IslandOp>(&block)) return failure();

    FetchOp fetch_op = op.GetFetch();
    auto island_op = llvm::cast<IslandOp>(block.front());
    YieldOp yield_op = island_op.GetYield();

    // Map graph results to the results of the single island's inner ops.
    llvm::SmallVector<Value, 8> new_rets;
    for (Value operand : fetch_op.fetches()) {
      // Control results should not be propagated out.
      if (operand.getType().isa<ControlType>()) break;

      if (operand.getDefiningOp() != island_op) {
        // Operand is not from the island, simply propagate it out.
        new_rets.push_back(operand);
      } else {
        // Look up the yield operand in the island for the inner op result.
        auto result = operand.cast<OpResult>();
        new_rets.push_back(yield_op.getOperand(result.getResultNumber()));
      }
    }

    // Move the inner ops from the island to the block containing the graph.
    auto &island_body = island_op.GetBody().getOperations();
    Operation *operation = op.getOperation();
    operation->getBlock()->getOperations().splice(
        operation->getIterator(), island_body, island_body.begin(),
        std::prev(island_body.end()));
    rewriter.replaceOp(op, new_rets);

    return success();
  }
};
}  // anonymous namespace

void GraphOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<DropEmptyGraph, HoistInnerOpsSingleIslandGraph>(context);
}

//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches and removes IslandOps with no inner ops, no control
// operands and no data results. Control result users will have their relevant
// operands removed.
struct DropEmptyIslandNoOperandNoDataResult
    : public OpRewritePattern<IslandOp> {
  using OpRewritePattern<IslandOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IslandOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 0 || op.getNumResults() != 1 ||
        !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
      return failure();

    for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
      use.getOwner()->eraseOperand(use.getOperandNumber());

    rewriter.eraseOp(op);

    return success();
  }
};

// This pattern matches and removes IslandOps with no inner ops, no control
// operands, one data result and no control result users. The single data
// result (from the YieldOp's first operand) is forwarded to the users of the
// IslandOp's single data result.
struct DropEmptyIslandNoOperandOneDataResult
    : public OpRewritePattern<IslandOp> {
  using OpRewritePattern<IslandOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IslandOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 0 || op.getNumResults() != 2 ||
        !op.control().use_empty() ||
        !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
      return failure();

    rewriter.replaceOp(op, {op.GetYield().getOperand(0), nullptr});

    return success();
  }
};
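
// For illustration (comment only, not compiled):
// DropEmptyIslandNoOperandOneDataResult rewrites
//
//   %val, %ctl = tf_executor.island {
//     tf_executor.yield %arg : tensor<i32>
//   }
//
// into direct uses of %arg, provided %ctl has no users.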

// TODO(lyandy): Add canonicalization for empty IslandOps with more than one
// control operand and no data results.

}  // anonymous namespace

void IslandOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<DropEmptyIslandNoOperandNoDataResult,
              DropEmptyIslandNoOperandOneDataResult>(context);
}

//===----------------------------------------------------------------------===//
// tf_executor.ControlTrigger
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches and removes ControlTriggerOps with no control operands.
// Control result users will have their relevant operands removed.
struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
  using OpRewritePattern<ControlTriggerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ControlTriggerOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 0) return failure();

    for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
      use.getOwner()->eraseOperand(use.getOperandNumber());

    rewriter.eraseOp(op);

    return success();
  }
};
}  // anonymous namespace

void ControlTriggerOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<DropEmptyControlTrigger>(context);
}

//===----------------------------------------------------------------------===//
// Folders
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//

LogicalResult IslandOp::fold(llvm::ArrayRef<Attribute> operands,
                             llvm::SmallVectorImpl<OpFoldResult> &results) {
  // This folds IslandOps with no inner ops, one control operand and no data
  // results. The single control operand is forwarded to the IslandOp control
  // result users.
  if (getNumOperands() != 1 || getNumResults() != 1 ||
      !HasSingleOpInBlock<YieldOp>(&GetBody()))
    return failure();

  results.emplace_back(getOperand(0));

  return success();
}

}  // namespace tf_executor
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"