1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
18
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project
20 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
21 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
26 #include "mlir/IR/Value.h" // from @llvm-project
27 #include "mlir/Pass/Pass.h" // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33
34 namespace mlir {
35 namespace TF {
36
37 namespace {
38
39 struct FunctionalControlFlowToCFG
40 : public FunctionalControlFlowToCFGPassBase<FunctionalControlFlowToCFG> {
41 void runOnOperation() override;
42 };
43
44 // Lowers a general tensor argument that is used as a condition to a functional
45 // control flow op into an i1 value.
LowerCondition(Location loc,Value value,OpBuilder * builder)46 static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
47 auto zero_d = builder->create<ToBoolOp>(loc, value);
48 auto scalar = builder->create<tensor::ExtractOp>(loc, zero_d);
49 return scalar.getResult();
50 }
51
52 // Calls the function `fn` with arguments provided by the given function and
53 // return the CallOp. Arguments are cast to the required type before calling
54 // the function.
55 //
56 // Requires the function to provide arguments for each of the `fn` operands
57 // that is compatible for tensor cast.
CallFn(Location loc,const std::function<Value (int)> & get_arg,func::FuncOp fn,OpBuilder * builder)58 static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
59 func::FuncOp fn, OpBuilder* builder) {
60 FunctionType fn_type = fn.getFunctionType();
61 llvm::SmallVector<Value, 4> operands;
62 int num_operands = fn_type.getNumInputs();
63 operands.reserve(num_operands);
64 for (int i = 0; i < num_operands; ++i) {
65 Value val = get_arg(i);
66 Type expected = fn_type.getInput(i);
67 if (val.getType() != expected) {
68 val =
69 builder->create<TF::CastOp>(loc, expected, val,
70 /*Truncate=*/builder->getBoolAttr(false));
71 }
72 operands.push_back(val);
73 }
74 return builder->create<func::CallOp>(loc, fn, operands).getOperation();
75 }
76
77 // Prepares for jump to the given block by introducing necessary tensor_cast
78 // operations and returning Values of types required by the block.
79 //
80 // Requires the function to provide values for each of the block arguments and
81 // they should be pair-wise compatible for tensor cast.
PrepareValsForJump(Location loc,const std::function<Value (int)> & get_val,Block * block,OpBuilder * builder)82 static llvm::SmallVector<Value, 4> PrepareValsForJump(
83 Location loc, const std::function<Value(int)>& get_val, Block* block,
84 OpBuilder* builder) {
85 llvm::SmallVector<Value, 4> result;
86 int num_vals = block->getNumArguments();
87 result.reserve(num_vals);
88 for (int i = 0; i < num_vals; ++i) {
89 Value val = get_val(i);
90 Type expected = block->getArgument(i).getType();
91 if (val.getType() != expected) {
92 val =
93 builder->create<TF::CastOp>(loc, expected, val,
94 /*Truncate=*/builder->getBoolAttr(false));
95 }
96 result.push_back(val);
97 }
98 return result;
99 }
100
101 // Jumps to the given block with arguments provided by the function. Arguments
102 // are cast to the required type before the jump.
103 //
104 // Requires the function to provide values for each of the block arguments and
105 // they should be pair-wise compatible for tensor cast.
JumpToBlock(Location loc,const std::function<Value (int)> & get_arg,Block * block,OpBuilder * builder)106 static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
107 Block* block, OpBuilder* builder) {
108 auto operands = PrepareValsForJump(loc, get_arg, block, builder);
109 builder->create<cf::BranchOp>(loc, block, operands);
110 }
111
112 // Replaces all uses of the operation results in this block with block
113 // arguments.
114 //
115 // Requires that the block has same number of arguments as number of results of
116 // the operation and either they have same types or are more generic types and
117 // it is possible to cast them to results' types.
ReplaceOpResultWithBlockArgs(Location loc,Operation * op,Block * block,OpBuilder * builder)118 static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
119 Block* block, OpBuilder* builder) {
120 assert(op->getNumResults() == block->getNumArguments());
121 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
122 Value arg = block->getArgument(i);
123 Value result = op->getResult(i);
124 if (arg.getType() != result.getType()) {
125 arg =
126 builder->create<TF::CastOp>(loc, result.getType(), arg,
127 /*Truncate=*/builder->getBoolAttr(false));
128 }
129 result.replaceAllUsesWith(arg);
130 }
131 }
132
133 // Given a functional IfOp, transforms the enclosing code to eliminate it
134 // completely from the IR, breaking it into operations to evaluate the condition
135 // as a bool, plus some branches.
LowerIfOp(IfOp op)136 static LogicalResult LowerIfOp(IfOp op) {
137 Operation* op_inst = op.getOperation();
138 Location loc = op_inst->getLoc();
139
140 OpBuilder builder(op_inst);
141
142 // Lower the condition to a boolean value (i1).
143 Value cond_i1 = LowerCondition(loc, op.cond(), &builder);
144 if (!cond_i1) return failure();
145
146 // Split the basic block before the 'if'. The new dest will be our merge
147 // point.
148 Block* orig_block = op_inst->getBlock();
149 Block* merge_block = orig_block->splitBlock(op);
150
151 // Add the block arguments to the merge point, and replace all uses of the
152 // original operation results with them.
153 for (Value value : op_inst->getResults())
154 merge_block->addArgument(value.getType(), loc);
155 ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
156
157 // Get arguments to the branches after dropping the condition which is the
158 // first operand.
159 auto get_operand = [&](int i) { return op_inst->getOperand(i + 1); };
160
161 // Set up the 'then' block.
162 Block* then_block = builder.createBlock(merge_block);
163 Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder);
164
165 auto get_then_result = [&](int i) { return call_op->getResult(i); };
166 JumpToBlock(loc, get_then_result, merge_block, &builder);
167
168 // Set up the 'else' block.
169 Block* else_block = builder.createBlock(merge_block);
170 call_op = CallFn(loc, get_operand, op.else_function(), &builder);
171
172 auto get_else_result = [&](int i) { return call_op->getResult(i); };
173 JumpToBlock(loc, get_else_result, merge_block, &builder);
174
175 // Now that we have the then and else blocks, replace the terminator of the
176 // orig_block with a conditional branch.
177 builder.setInsertionPointToEnd(orig_block);
178 builder.create<cf::CondBranchOp>(loc, cond_i1, then_block,
179 llvm::ArrayRef<Value>(), else_block,
180 llvm::ArrayRef<Value>());
181
182 // Finally, delete the op in question.
183 op_inst->erase();
184 return success();
185 }
186
187 // Given a functional WhileOp, transforms the enclosing code to eliminate it
188 // completely from the IR, breaking it into operations to execute the loop body
189 // repeatedly while the loop condition is true.
LowerWhileOp(WhileOp op)190 static LogicalResult LowerWhileOp(WhileOp op) {
191 Operation* op_inst = op.getOperation();
192 Location loc = op_inst->getLoc();
193
194 OpBuilder builder(op_inst);
195
196 auto cond_fn = op.cond_function();
197 auto body_fn = op.body_function();
198
199 // Split the block containing the While op into two blocks. One containing
200 // operations before the While op and other containing the rest. Create two
201 // new blocks to call condition and body functions.
202 //
203 // The final control flow graph would be as follows:
204 //
205 // ...
206 // orig_block_head(...):
207 // ...
208 // br cond_block(...)
209 // cond_block(...):
210 // %A = call @cond(...)
211 // cond br %A, body_block(...), orig_block_tail(...)
212 // body_block(...):
213 // %B = call @body(...)
214 // br cond_block(...)
215 // orig_block_tail(...):
216 // ...
217 //
218 Block* orig_block_head = op_inst->getBlock();
219 Block* orig_block_tail = orig_block_head->splitBlock(op);
220 Block* cond_block = builder.createBlock(orig_block_tail);
221 Block* body_block = builder.createBlock(orig_block_tail);
222
223 // Set argument types for the cond_block to be same as the types of the
224 // condition function and argument types for the other two blocks to be same
225 // as the input types of the body function. Note that it is always possible
226 // for body_block and orig_block_tail to have arguments of the same types as
227 // they have exactly one call-site and they are sharing the operands.
228 for (Type type : cond_fn.getFunctionType().getInputs()) {
229 cond_block->addArgument(type, loc);
230 }
231 for (Type type : body_fn.getFunctionType().getInputs()) {
232 body_block->addArgument(type, loc);
233 orig_block_tail->addArgument(type, loc);
234 }
235
236 auto get_operand = [&](int i) { return op_inst->getOperand(i); };
237
238 // Unconditionally branch from the original block to the block containing the
239 // condition.
240 builder.setInsertionPointToEnd(orig_block_head);
241 JumpToBlock(loc, get_operand, cond_block, &builder);
242
243 // Call condition function in the condition block and then branch to the body
244 // block or remainder of the original block depending on condition function
245 // result.
246 builder.setInsertionPointToEnd(cond_block);
247
248 auto get_cond_arg = [&](int i) { return cond_block->getArgument(i); };
249 Operation* cond_call_op = CallFn(loc, get_cond_arg, cond_fn, &builder);
250
251 assert(cond_call_op->getNumResults() == 1);
252 Value condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
253 auto br_operands =
254 PrepareValsForJump(loc, get_cond_arg, body_block, &builder);
255 builder.create<cf::CondBranchOp>(loc, condition, body_block, br_operands,
256 orig_block_tail, br_operands);
257
258 // Call body function in the body block and then unconditionally branch back
259 // to the condition block.
260 builder.setInsertionPointToEnd(body_block);
261 auto get_body_arg = [&](int i) { return body_block->getArgument(i); };
262 Operation* body_call_op = CallFn(loc, get_body_arg, body_fn, &builder);
263
264 auto get_body_result = [&](int i) { return body_call_op->getResult(i); };
265 JumpToBlock(loc, get_body_result, cond_block, &builder);
266
267 // Replace use of the while loop results with block inputs in the remainder of
268 // the original block and then delete the original While operation.
269 builder.setInsertionPoint(&orig_block_tail->front());
270 ReplaceOpResultWithBlockArgs(loc, op_inst, orig_block_tail, &builder);
271 op_inst->erase();
272
273 return success();
274 }
275
runOnOperation()276 void FunctionalControlFlowToCFG::runOnOperation() {
277 // Scan the function looking for these ops.
278 for (Block& block : getOperation()) {
279 for (Operation& op : block) {
280 // If the operation is one of the control flow ops we know, lower it.
281 // If we lower an operation, then the current basic block will be split,
282 // and the operation will be removed, so we should continue looking at
283 // subsequent blocks.
284 //
285 // TODO: Use PatternRewriter to eliminate these function control flow ops.
286
287 if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
288 if (failed(LowerIfOp(if_op))) {
289 return signalPassFailure();
290 }
291 break;
292 }
293 if (WhileOp while_op = llvm::dyn_cast<WhileOp>(op)) {
294 if (failed(LowerWhileOp(while_op))) {
295 return signalPassFailure();
296 }
297 break;
298 }
299 }
300 }
301 }
302
303 } // namespace
304
305 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFFunctionalControlFlowToCFG()306 CreateTFFunctionalControlFlowToCFG() {
307 return std::make_unique<FunctionalControlFlowToCFG>();
308 }
309
310 } // namespace TF
311 } // namespace mlir
312