1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to their region based counterparts, i.e.,
18 // tf.If -> tf.IfRegion and tf.While -> tf.WhileRegion
19 
20 #include "llvm/Support/raw_ostream.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
28 #include "mlir/IR/Value.h"  // from @llvm-project
29 #include "mlir/IR/Verifier.h"  // from @llvm-project
30 #include "mlir/IR/Visitors.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
38 
39 #define DEBUG_TYPE "tf-functional-cf-to-region"
40 
41 namespace mlir {
42 namespace TF {
43 
44 namespace {
45 
46 struct FunctionalControlFlowToRegions
47     : public TF::FunctionalControlFlowToRegionsPassBase<
48           FunctionalControlFlowToRegions> {
49   void runOnOperation() override;
50 };
51 
52 // Creates a call to function `func` in region `caller_region`. Use `args` as
53 // the call arguments, and terminate the region with a yield. The arguments are
54 // cast to the required type before the call. `use_region_args` control whether
55 // the input arguments are used as is (for IfOp) or block arguments of the same
56 // type as the input arguments are created and then used as call arguments (for
57 // While).
CreateCall(Operation * op,func::FuncOp func,Region & caller_region,ValueRange args,bool use_region_args)58 YieldOp CreateCall(Operation* op, func::FuncOp func, Region& caller_region,
59                    ValueRange args, bool use_region_args) {
60   assert(caller_region.empty() &&
61          "Expected empty region for newly created ops");
62   OpBuilder builder(caller_region);
63   Block* entry = builder.createBlock(&caller_region);
64 
65   auto loc = op->getLoc();
66   if (use_region_args) {
67     auto inputs = func.getFunctionType().getInputs();
68     entry->addArguments(inputs, SmallVector<Location>(inputs.size(), loc));
69     args = entry->getArguments();
70   }
71   llvm::SmallVector<Value, 4> casted_args;
72   casted_args.reserve(func.getNumArguments());
73   for (const auto& ArgAndType : zip(args, func.getFunctionType().getInputs())) {
74     Value arg = std::get<0>(ArgAndType);
75     Type expected_type = std::get<1>(ArgAndType);
76     if (arg.getType() != expected_type) {
77       arg = builder.create<CastOp>(loc, expected_type, arg,
78                                    /*Truncate=*/builder.getBoolAttr(false));
79     }
80     casted_args.push_back(arg);
81   }
82   auto call = builder.create<func::CallOp>(loc, func, casted_args);
83   return builder.create<YieldOp>(loc, call.getResults());
84 }
85 
86 // Converts the condition for an IfOp/WhileOp to a boolean value.
ConvertConditionToBoolean(Operation * op,Value cond)87 Value ConvertConditionToBoolean(Operation* op, Value cond) {
88   if (auto ranked_type = cond.getType().dyn_cast<RankedTensorType>())
89     if (ranked_type.getRank() == 0 &&
90         ranked_type.getElementType().isSignlessInteger(1))
91       return cond;
92 
93   OpBuilder builder(op);
94   return builder.create<TF::ToBoolOp>(op->getLoc(), cond);
95 }
96 
97 // Transform a functional IfOp to a region based IfRegionOp.
ConvertIfOp(IfOp if_op)98 LogicalResult ConvertIfOp(IfOp if_op) {
99   Value cond = ConvertConditionToBoolean(if_op, if_op.cond());
100   OpBuilder builder(if_op);
101   auto if_region = builder.create<TF::IfRegionOp>(
102       if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless(),
103       builder.getStringAttr(if_op.then_function().getName()),
104       builder.getStringAttr(if_op.else_function().getName()));
105   CopyDeviceAndUnderscoredAttributes(if_op, if_region);
106 
107   CreateCall(if_op, if_op.then_function(),
108              /*caller_region=*/if_region.then_branch(), if_op.input(),
109              /*use_region_args=*/false);
110   CreateCall(if_op, if_op.else_function(),
111              /*caller_region=*/if_region.else_branch(), if_op.input(),
112              /*use_region_args=*/false);
113   if_op.replaceAllUsesWith(if_region.getResults());
114   if_op.erase();
115   return success();
116 }
117 
ConvertWhileOp(WhileOp while_op)118 LogicalResult ConvertWhileOp(WhileOp while_op) {
119   auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
120       while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
121       while_op.parallel_iterations(), while_op.is_stateless(),
122       while_op.shape_invariant());
123   CopyDeviceAndUnderscoredAttributes(while_op, while_region);
124 
125   YieldOp cond_yield =
126       CreateCall(while_op, while_op.cond_function(),
127                  /*caller_region=*/while_region.cond(), while_op.input(),
128                  /*use_region_args=*/true);
129   Value i1_cond =
130       ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0));
131   cond_yield.setOperand(0, i1_cond);
132 
133   CreateCall(while_op, while_op.body_function(),
134              /*caller_region=*/while_region.body(), while_op.input(),
135              /*use_region_args=*/true);
136   while_op.replaceAllUsesWith(while_region.getResults());
137   while_op.erase();
138   return success();
139 }
140 
runOnOperation()141 void FunctionalControlFlowToRegions::runOnOperation() {
142   ModuleOp module = getOperation();
143   auto result = module.walk([](Operation* op) {
144     if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
145       if (failed(ConvertIfOp(if_op))) {
146         op->emitOpError() << "failed to convert to region form";
147         return WalkResult::interrupt();
148       }
149     } else if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
150       if (failed(ConvertWhileOp(while_op))) {
151         op->emitOpError() << "failed to convert to region form";
152         return WalkResult::interrupt();
153       }
154     }
155     return WalkResult::advance();
156   });
157   if (result.wasInterrupted()) return signalPassFailure();
158 }
159 }  // namespace
160 
161 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFFunctionalControlFlowToRegions()162 CreateTFFunctionalControlFlowToRegions() {
163   return std::make_unique<FunctionalControlFlowToRegions>();
164 }
165 
166 }  // namespace TF
167 }  // namespace mlir
168