1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
18 
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"  // from @llvm-project
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
26 #include "mlir/IR/Value.h"  // from @llvm-project
27 #include "mlir/Pass/Pass.h"  // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33 
34 namespace mlir {
35 namespace TF {
36 
37 namespace {
38 
39 struct FunctionalControlFlowToCFG
40     : public FunctionalControlFlowToCFGPassBase<FunctionalControlFlowToCFG> {
41   void runOnOperation() override;
42 };
43 
44 // Lowers a general tensor argument that is used as a condition to a functional
45 // control flow op into an i1 value.
LowerCondition(Location loc,Value value,OpBuilder * builder)46 static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
47   auto zero_d = builder->create<ToBoolOp>(loc, value);
48   auto scalar = builder->create<tensor::ExtractOp>(loc, zero_d);
49   return scalar.getResult();
50 }
51 
52 // Calls the function `fn` with arguments provided by the given function and
53 // return the CallOp. Arguments are cast to the required type before calling
54 // the function.
55 //
56 // Requires the function to provide arguments for each of the `fn` operands
57 // that is compatible for tensor cast.
CallFn(Location loc,const std::function<Value (int)> & get_arg,func::FuncOp fn,OpBuilder * builder)58 static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
59                          func::FuncOp fn, OpBuilder* builder) {
60   FunctionType fn_type = fn.getFunctionType();
61   llvm::SmallVector<Value, 4> operands;
62   int num_operands = fn_type.getNumInputs();
63   operands.reserve(num_operands);
64   for (int i = 0; i < num_operands; ++i) {
65     Value val = get_arg(i);
66     Type expected = fn_type.getInput(i);
67     if (val.getType() != expected) {
68       val =
69           builder->create<TF::CastOp>(loc, expected, val,
70                                       /*Truncate=*/builder->getBoolAttr(false));
71     }
72     operands.push_back(val);
73   }
74   return builder->create<func::CallOp>(loc, fn, operands).getOperation();
75 }
76 
77 // Prepares for jump to the given block by introducing necessary tensor_cast
78 // operations and returning Values of types required by the block.
79 //
80 // Requires the function to provide values for each of the block arguments and
81 // they should be pair-wise compatible for tensor cast.
PrepareValsForJump(Location loc,const std::function<Value (int)> & get_val,Block * block,OpBuilder * builder)82 static llvm::SmallVector<Value, 4> PrepareValsForJump(
83     Location loc, const std::function<Value(int)>& get_val, Block* block,
84     OpBuilder* builder) {
85   llvm::SmallVector<Value, 4> result;
86   int num_vals = block->getNumArguments();
87   result.reserve(num_vals);
88   for (int i = 0; i < num_vals; ++i) {
89     Value val = get_val(i);
90     Type expected = block->getArgument(i).getType();
91     if (val.getType() != expected) {
92       val =
93           builder->create<TF::CastOp>(loc, expected, val,
94                                       /*Truncate=*/builder->getBoolAttr(false));
95     }
96     result.push_back(val);
97   }
98   return result;
99 }
100 
101 // Jumps to the given block with arguments provided by the function. Arguments
102 // are cast to the required type before the jump.
103 //
104 // Requires the function to provide values for each of the block arguments and
105 // they should be pair-wise compatible for tensor cast.
JumpToBlock(Location loc,const std::function<Value (int)> & get_arg,Block * block,OpBuilder * builder)106 static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
107                         Block* block, OpBuilder* builder) {
108   auto operands = PrepareValsForJump(loc, get_arg, block, builder);
109   builder->create<cf::BranchOp>(loc, block, operands);
110 }
111 
112 // Replaces all uses of the operation results in this block with block
113 // arguments.
114 //
115 // Requires that the block has same number of arguments as number of results of
116 // the operation and either they have same types or are more generic types and
117 // it is possible to cast them to results' types.
ReplaceOpResultWithBlockArgs(Location loc,Operation * op,Block * block,OpBuilder * builder)118 static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
119                                          Block* block, OpBuilder* builder) {
120   assert(op->getNumResults() == block->getNumArguments());
121   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
122     Value arg = block->getArgument(i);
123     Value result = op->getResult(i);
124     if (arg.getType() != result.getType()) {
125       arg =
126           builder->create<TF::CastOp>(loc, result.getType(), arg,
127                                       /*Truncate=*/builder->getBoolAttr(false));
128     }
129     result.replaceAllUsesWith(arg);
130   }
131 }
132 
133 // Given a functional IfOp, transforms the enclosing code to eliminate it
134 // completely from the IR, breaking it into operations to evaluate the condition
135 // as a bool, plus some branches.
LowerIfOp(IfOp op)136 static LogicalResult LowerIfOp(IfOp op) {
137   Operation* op_inst = op.getOperation();
138   Location loc = op_inst->getLoc();
139 
140   OpBuilder builder(op_inst);
141 
142   // Lower the condition to a boolean value (i1).
143   Value cond_i1 = LowerCondition(loc, op.cond(), &builder);
144   if (!cond_i1) return failure();
145 
146   // Split the basic block before the 'if'.  The new dest will be our merge
147   // point.
148   Block* orig_block = op_inst->getBlock();
149   Block* merge_block = orig_block->splitBlock(op);
150 
151   // Add the block arguments to the merge point, and replace all uses of the
152   // original operation results with them.
153   for (Value value : op_inst->getResults())
154     merge_block->addArgument(value.getType(), loc);
155   ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
156 
157   // Get arguments to the branches after dropping the condition which is the
158   // first operand.
159   auto get_operand = [&](int i) { return op_inst->getOperand(i + 1); };
160 
161   // Set up the 'then' block.
162   Block* then_block = builder.createBlock(merge_block);
163   Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder);
164 
165   auto get_then_result = [&](int i) { return call_op->getResult(i); };
166   JumpToBlock(loc, get_then_result, merge_block, &builder);
167 
168   // Set up the 'else' block.
169   Block* else_block = builder.createBlock(merge_block);
170   call_op = CallFn(loc, get_operand, op.else_function(), &builder);
171 
172   auto get_else_result = [&](int i) { return call_op->getResult(i); };
173   JumpToBlock(loc, get_else_result, merge_block, &builder);
174 
175   // Now that we have the then and else blocks, replace the terminator of the
176   // orig_block with a conditional branch.
177   builder.setInsertionPointToEnd(orig_block);
178   builder.create<cf::CondBranchOp>(loc, cond_i1, then_block,
179                                    llvm::ArrayRef<Value>(), else_block,
180                                    llvm::ArrayRef<Value>());
181 
182   // Finally, delete the op in question.
183   op_inst->erase();
184   return success();
185 }
186 
187 // Given a functional WhileOp, transforms the enclosing code to eliminate it
188 // completely from the IR, breaking it into operations to execute the loop body
189 // repeatedly while the loop condition is true.
LowerWhileOp(WhileOp op)190 static LogicalResult LowerWhileOp(WhileOp op) {
191   Operation* op_inst = op.getOperation();
192   Location loc = op_inst->getLoc();
193 
194   OpBuilder builder(op_inst);
195 
196   auto cond_fn = op.cond_function();
197   auto body_fn = op.body_function();
198 
199   // Split the block containing the While op into two blocks.  One containing
200   // operations before the While op and other containing the rest.  Create two
201   // new blocks to call condition and body functions.
202   //
203   // The final control flow graph would be as follows:
204   //
205   // ...
206   // orig_block_head(...):
207   //   ...
208   //   br cond_block(...)
209   // cond_block(...):
210   //   %A = call @cond(...)
211   //   cond br %A, body_block(...), orig_block_tail(...)
212   // body_block(...):
213   //   %B = call @body(...)
214   //   br cond_block(...)
215   // orig_block_tail(...):
216   //   ...
217   //
218   Block* orig_block_head = op_inst->getBlock();
219   Block* orig_block_tail = orig_block_head->splitBlock(op);
220   Block* cond_block = builder.createBlock(orig_block_tail);
221   Block* body_block = builder.createBlock(orig_block_tail);
222 
223   // Set argument types for the cond_block to be same as the types of the
224   // condition function and argument types for the other two blocks to be same
225   // as the input types of the body function. Note that it is always possible
226   // for body_block and orig_block_tail to have arguments of the same types as
227   // they have exactly one call-site and they are sharing the operands.
228   for (Type type : cond_fn.getFunctionType().getInputs()) {
229     cond_block->addArgument(type, loc);
230   }
231   for (Type type : body_fn.getFunctionType().getInputs()) {
232     body_block->addArgument(type, loc);
233     orig_block_tail->addArgument(type, loc);
234   }
235 
236   auto get_operand = [&](int i) { return op_inst->getOperand(i); };
237 
238   // Unconditionally branch from the original block to the block containing the
239   // condition.
240   builder.setInsertionPointToEnd(orig_block_head);
241   JumpToBlock(loc, get_operand, cond_block, &builder);
242 
243   // Call condition function in the condition block and then branch to the body
244   // block or remainder of the original block depending on condition function
245   // result.
246   builder.setInsertionPointToEnd(cond_block);
247 
248   auto get_cond_arg = [&](int i) { return cond_block->getArgument(i); };
249   Operation* cond_call_op = CallFn(loc, get_cond_arg, cond_fn, &builder);
250 
251   assert(cond_call_op->getNumResults() == 1);
252   Value condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
253   auto br_operands =
254       PrepareValsForJump(loc, get_cond_arg, body_block, &builder);
255   builder.create<cf::CondBranchOp>(loc, condition, body_block, br_operands,
256                                    orig_block_tail, br_operands);
257 
258   // Call body function in the body block and then unconditionally branch back
259   // to the condition block.
260   builder.setInsertionPointToEnd(body_block);
261   auto get_body_arg = [&](int i) { return body_block->getArgument(i); };
262   Operation* body_call_op = CallFn(loc, get_body_arg, body_fn, &builder);
263 
264   auto get_body_result = [&](int i) { return body_call_op->getResult(i); };
265   JumpToBlock(loc, get_body_result, cond_block, &builder);
266 
267   // Replace use of the while loop results with block inputs in the remainder of
268   // the original block and then delete the original While operation.
269   builder.setInsertionPoint(&orig_block_tail->front());
270   ReplaceOpResultWithBlockArgs(loc, op_inst, orig_block_tail, &builder);
271   op_inst->erase();
272 
273   return success();
274 }
275 
runOnOperation()276 void FunctionalControlFlowToCFG::runOnOperation() {
277   // Scan the function looking for these ops.
278   for (Block& block : getOperation()) {
279     for (Operation& op : block) {
280       // If the operation is one of the control flow ops we know, lower it.
281       // If we lower an operation, then the current basic block will be split,
282       // and the operation will be removed, so we should continue looking at
283       // subsequent blocks.
284       //
285       // TODO: Use PatternRewriter to eliminate these function control flow ops.
286 
287       if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
288         if (failed(LowerIfOp(if_op))) {
289           return signalPassFailure();
290         }
291         break;
292       }
293       if (WhileOp while_op = llvm::dyn_cast<WhileOp>(op)) {
294         if (failed(LowerWhileOp(while_op))) {
295           return signalPassFailure();
296         }
297         break;
298       }
299     }
300   }
301 }
302 
303 }  // namespace
304 
305 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFFunctionalControlFlowToCFG()306 CreateTFFunctionalControlFlowToCFG() {
307   return std::make_unique<FunctionalControlFlowToCFG>();
308 }
309 
310 }  // namespace TF
311 }  // namespace mlir
312