1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/Block.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/IR/Operation.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30
31 namespace mlir {
32 namespace TFDevice {
33
34 namespace {
35
// Name of the attribute that holds the symbol reference to the outlined
// function. It is set on the cluster/launch op and then copied, together with
// that op's other attributes, onto the generated cluster_func/launch_func op.
constexpr char kFuncAttr[]{"func"};
37
38 struct ClusterOutliningPass
39 : public TF::ClusterOutliningPassBase<ClusterOutliningPass> {
40 void runOnOperation() override;
41 };
42
43 struct LaunchOutliningPass
44 : public TF::LaunchOutliningPassBase<LaunchOutliningPass> {
45 void runOnOperation() override;
46 };
47
ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,OpBuilder * builder)48 void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,
49 OpBuilder* builder) {
50 builder->create<func::ReturnOp>(cluster_return_op.getLoc(),
51 cluster_return_op.getOperands());
52 cluster_return_op.erase();
53 }
54
55 // Builds a function that outlines region attached to cluster_op or launch_op,
56 // and inserts built function into given module.
57 template <typename ClusterOrLaunchOp>
BuildFunction(llvm::ArrayRef<Value> live_ins,ClusterOrLaunchOp op,SymbolTable * symbol_table,OpBuilder * builder)58 func::FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins, ClusterOrLaunchOp op,
59 SymbolTable* symbol_table, OpBuilder* builder) {
60 llvm::SmallVector<Type, 4> operand_types;
61 operand_types.reserve(live_ins.size());
62 for (Value v : live_ins) operand_types.emplace_back(v.getType());
63
64 auto func_type = builder->getFunctionType(operand_types, op.getResultTypes());
65
66 // TODO(lyandy): Define better name for outlined function. Potentially some
67 // name can be added during cluster formation.
68 func::FuncOp outlined_func =
69 func::FuncOp::create(op.getLoc(), "_func", func_type);
70
71 // This function is not externally visible and marking it private would allow
72 // symbol-dce pass to remove it when it is not referenced anymore.
73 outlined_func.setPrivate();
74
75 // Create function body.
76 Block* outlined_func_block = outlined_func.addEntryBlock();
77
78 // Replace uses of live-in values within cluster_op region with function
79 // arguments.
80 Region& op_region = op.body();
81 for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) {
82 replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op_region);
83 }
84
85 // Move all instructions in cluster_op into outlined_function's only block.
86 auto& op_body = op.GetBody().getOperations();
87 outlined_func_block->getOperations().splice(
88 outlined_func_block->end(), op_body, op_body.begin(), op_body.end());
89
90 // Replace `tf_device.return` terminator with `std.return` in function
91 // body.
92 auto return_op =
93 cast<tf_device::ReturnOp>(outlined_func_block->getTerminator());
94 builder->setInsertionPoint(return_op);
95 ReplaceClusterReturnWithReturn(return_op, builder);
96
97 symbol_table->insert(outlined_func);
98 return outlined_func;
99 }
100
101 // Outlines body of `tf_device.cluster` into a function and create a
102 // `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is
103 // removed afterwards.`
OutlineCluster(tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)104 void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
105 OpBuilder* builder) {
106 llvm::SetVector<Value> live_ins;
107 getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins);
108
109 func::FuncOp outlined_func =
110 BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
111 cluster_op->setAttr(
112 builder->getStringAttr(kFuncAttr),
113 mlir::SymbolRefAttr::get(builder->getContext(), outlined_func.getName()));
114
115 builder->setInsertionPoint(cluster_op);
116 auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
117 cluster_op.getLoc(), outlined_func.getFunctionType().getResults(),
118 live_ins.getArrayRef(), cluster_op->getAttrs());
119
120 cluster_op.replaceAllUsesWith(cluster_func_op);
121 cluster_op.erase();
122 }
123
124 // Outlines body of `tf_device.launch` into a function and create a
125 // `tf_device.launch_func` to invoke that function. `tf_device.launch` is
126 // removed afterwards.`
OutlineLaunch(tf_device::LaunchOp launch_op,SymbolTable * symbol_table,OpBuilder * builder)127 void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
128 OpBuilder* builder) {
129 llvm::SetVector<Value> live_ins;
130 getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
131
132 func::FuncOp outlined_func =
133 BuildFunction(live_ins.getArrayRef(), launch_op, symbol_table, builder);
134 launch_op->setAttr(
135 builder->getStringAttr(kFuncAttr),
136 mlir::SymbolRefAttr::get(builder->getContext(), outlined_func.getName()));
137
138 builder->setInsertionPoint(launch_op);
139 auto cluster_func_op = builder->create<tf_device::LaunchFuncOp>(
140 launch_op.getLoc(), outlined_func.getFunctionType().getResults(),
141 live_ins.getArrayRef(), launch_op->getAttrs());
142
143 launch_op.replaceAllUsesWith(cluster_func_op);
144 launch_op.erase();
145 }
146
runOnOperation()147 void ClusterOutliningPass::runOnOperation() {
148 ModuleOp module = getOperation();
149 SymbolTable symbol_table(module);
150 OpBuilder builder(module.getContext());
151 module.walk([&](tf_device::ClusterOp cluster) {
152 OutlineCluster(cluster, &symbol_table, &builder);
153 });
154 }
155
runOnOperation()156 void LaunchOutliningPass::runOnOperation() {
157 ModuleOp module = getOperation();
158 SymbolTable symbol_table(module);
159 OpBuilder builder(module.getContext());
160 module.walk([&](tf_device::LaunchOp launch) {
161 OutlineLaunch(launch, &symbol_table, &builder);
162 });
163 }
164
165 } // namespace
166
CreateClusterOutliningPass()167 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
168 return std::make_unique<ClusterOutliningPass>();
169 }
170
CreateLaunchOutliningPass()171 std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchOutliningPass() {
172 return std::make_unique<LaunchOutliningPass>();
173 }
174
175 } // namespace TFDevice
176 } // namespace mlir
177