1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to their region based counterparts, i.e.,
18 // tf.If -> tf.IfRegion and tf.While -> tf.WhileRegion
19
20 #include "llvm/Support/raw_ostream.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
28 #include "mlir/IR/Value.h" // from @llvm-project
29 #include "mlir/IR/Verifier.h" // from @llvm-project
30 #include "mlir/IR/Visitors.h" // from @llvm-project
31 #include "mlir/Pass/Pass.h" // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
38
39 #define DEBUG_TYPE "tf-functional-cf-to-region"
40
41 namespace mlir {
42 namespace TF {
43
44 namespace {
45
46 struct FunctionalControlFlowToRegions
47 : public TF::FunctionalControlFlowToRegionsPassBase<
48 FunctionalControlFlowToRegions> {
49 void runOnOperation() override;
50 };
51
52 // Creates a call to function `func` in region `caller_region`. Use `args` as
53 // the call arguments, and terminate the region with a yield. The arguments are
54 // cast to the required type before the call. `use_region_args` control whether
55 // the input arguments are used as is (for IfOp) or block arguments of the same
56 // type as the input arguments are created and then used as call arguments (for
57 // While).
CreateCall(Operation * op,func::FuncOp func,Region & caller_region,ValueRange args,bool use_region_args)58 YieldOp CreateCall(Operation* op, func::FuncOp func, Region& caller_region,
59 ValueRange args, bool use_region_args) {
60 assert(caller_region.empty() &&
61 "Expected empty region for newly created ops");
62 OpBuilder builder(caller_region);
63 Block* entry = builder.createBlock(&caller_region);
64
65 auto loc = op->getLoc();
66 if (use_region_args) {
67 auto inputs = func.getFunctionType().getInputs();
68 entry->addArguments(inputs, SmallVector<Location>(inputs.size(), loc));
69 args = entry->getArguments();
70 }
71 llvm::SmallVector<Value, 4> casted_args;
72 casted_args.reserve(func.getNumArguments());
73 for (const auto& ArgAndType : zip(args, func.getFunctionType().getInputs())) {
74 Value arg = std::get<0>(ArgAndType);
75 Type expected_type = std::get<1>(ArgAndType);
76 if (arg.getType() != expected_type) {
77 arg = builder.create<CastOp>(loc, expected_type, arg,
78 /*Truncate=*/builder.getBoolAttr(false));
79 }
80 casted_args.push_back(arg);
81 }
82 auto call = builder.create<func::CallOp>(loc, func, casted_args);
83 return builder.create<YieldOp>(loc, call.getResults());
84 }
85
86 // Converts the condition for an IfOp/WhileOp to a boolean value.
ConvertConditionToBoolean(Operation * op,Value cond)87 Value ConvertConditionToBoolean(Operation* op, Value cond) {
88 if (auto ranked_type = cond.getType().dyn_cast<RankedTensorType>())
89 if (ranked_type.getRank() == 0 &&
90 ranked_type.getElementType().isSignlessInteger(1))
91 return cond;
92
93 OpBuilder builder(op);
94 return builder.create<TF::ToBoolOp>(op->getLoc(), cond);
95 }
96
97 // Transform a functional IfOp to a region based IfRegionOp.
ConvertIfOp(IfOp if_op)98 LogicalResult ConvertIfOp(IfOp if_op) {
99 Value cond = ConvertConditionToBoolean(if_op, if_op.cond());
100 OpBuilder builder(if_op);
101 auto if_region = builder.create<TF::IfRegionOp>(
102 if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless(),
103 builder.getStringAttr(if_op.then_function().getName()),
104 builder.getStringAttr(if_op.else_function().getName()));
105 CopyDeviceAndUnderscoredAttributes(if_op, if_region);
106
107 CreateCall(if_op, if_op.then_function(),
108 /*caller_region=*/if_region.then_branch(), if_op.input(),
109 /*use_region_args=*/false);
110 CreateCall(if_op, if_op.else_function(),
111 /*caller_region=*/if_region.else_branch(), if_op.input(),
112 /*use_region_args=*/false);
113 if_op.replaceAllUsesWith(if_region.getResults());
114 if_op.erase();
115 return success();
116 }
117
ConvertWhileOp(WhileOp while_op)118 LogicalResult ConvertWhileOp(WhileOp while_op) {
119 auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
120 while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
121 while_op.parallel_iterations(), while_op.is_stateless(),
122 while_op.shape_invariant());
123 CopyDeviceAndUnderscoredAttributes(while_op, while_region);
124
125 YieldOp cond_yield =
126 CreateCall(while_op, while_op.cond_function(),
127 /*caller_region=*/while_region.cond(), while_op.input(),
128 /*use_region_args=*/true);
129 Value i1_cond =
130 ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0));
131 cond_yield.setOperand(0, i1_cond);
132
133 CreateCall(while_op, while_op.body_function(),
134 /*caller_region=*/while_region.body(), while_op.input(),
135 /*use_region_args=*/true);
136 while_op.replaceAllUsesWith(while_region.getResults());
137 while_op.erase();
138 return success();
139 }
140
runOnOperation()141 void FunctionalControlFlowToRegions::runOnOperation() {
142 ModuleOp module = getOperation();
143 auto result = module.walk([](Operation* op) {
144 if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
145 if (failed(ConvertIfOp(if_op))) {
146 op->emitOpError() << "failed to convert to region form";
147 return WalkResult::interrupt();
148 }
149 } else if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
150 if (failed(ConvertWhileOp(while_op))) {
151 op->emitOpError() << "failed to convert to region form";
152 return WalkResult::interrupt();
153 }
154 }
155 return WalkResult::advance();
156 });
157 if (result.wasInterrupted()) return signalPassFailure();
158 }
159 } // namespace
160
161 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFFunctionalControlFlowToRegions()162 CreateTFFunctionalControlFlowToRegions() {
163 return std::make_unique<FunctionalControlFlowToRegions>();
164 }
165
166 } // namespace TF
167 } // namespace mlir
168